显然, y i y_i yi 加上 c c c 可以看成是 x i x_i xi 减去 c c c。
所以就变成了 x i x_i xi 加上一个整数(可正可负)。
现将 x x x 环拆成一个长度为 2 n 2n 2n 的序列 a a a(复制一遍),把 y y y 环拆成一个长度为 n n n 的序列 b b b。
那么旋转操作就可以看成是 b b b 序列与 a a a 序列中每一个长度为 n n n 的子串匹配求值。
也就是说,求这个东西的最小值: ∑ j = 0 n − 1 ( a i + j − b j + c ) 2 sum_{j=0}^{n-1}(a_{i+j}-b_j+c)^2 ∑j=0n−1(ai+j−bj+c)2( 0 ≤ i < n 0leq i< n 0≤i<n)。
接下来推式子:(设 x x x 环的和是 s u m a suma suma, y y y 环的和是 s u m b sumb sumb, x x x 环的平方和是 p o w a powa powa, y y y 环的平方和是 p o w a powa powa)
min i = 0 n − 1 ∑ j = 0 n − 1 ( a i + j − b j + c ) 2 = min i = 0 n − 1 ∑ j = 0 n − 1 ( a i + j − b j ) 2 + c 2 + 2 ( a i + j − b j ) c = min i = 0 n − 1 [ n c 2 + ∑ j = 0 n − 1 ( a i + j − b j ) 2 + 2 c ∑ j = 0 n − 1 ( a i + j − b j ) ] = min i = 0 n − 1 [ n c 2 + ∑ j = 0 n − 1 ( a i + j 2 + b j 2 − 2 a i + j b j ) + 2 c ( ∑ j = 0 n − 1 a i + j − ∑ j = 0 n − 1 b j ) ] = min i = 0 n − 1 ( n c 2 + p o w a + p o w b − 2 ∑ j = 0 n − 1 a i + j b j + 2 c ( s u m a − s u m b ) ) egin{aligned} &min_{i=0}^{n-1} sum_{j=0}^{n-1}(a_{i+j}-b_j+c)^2\ =&min_{i=0}^{n-1}sum_{j=0}^{n-1}(a_{i+j}-b_j)^2+c^2+2(a_{i+j}-b_j)c\ =&min_{i=0}^{n-1} [nc^2+sum_{j=0}^{n-1}(a_{i+j}-b_j)^2+2csum_{j=0}^{n-1}(a_{i+j}-b_j)]\ =&min_{i=0}^{n-1} [nc^2+sum_{j=0}^{n-1}(a_{i+j}^2+b_j^2-2a_{i+j}b_j)+2c(sum_{j=0}^{n-1} a_{i+j}-sum_{j=0}^{n-1} b_j)]\ =&min_{i=0}^{n-1} (nc^2+powa+powb-2sum_{j=0}^{n-1}a_{i+j}b_j+2c(suma-sumb)) end{aligned} ====i=0minn−1j=0∑n−1(ai+j−bj+c)2i=0minn−1j=0∑n−1(ai+j−bj)2+c2+2(ai+j−bj)ci=0minn−1[nc2+j=0∑n−1(ai+j−bj)2+2cj=0∑n−1(ai+j−bj)]i=0minn−1[nc2+j=0∑n−1(ai+j2+bj2−2ai+jbj)+2c(j=0∑n−1ai+j−j=0∑n−1bj)]i=0minn−1(nc2+powa+powb−2j=0∑n−1ai+jbj+2c(suma−sumb))
发现 p o w a + p o w b powa+powb powa+powb 是定值,而和 c c c 有关的 n c 2 + 2 c ( s u m a − s u m b ) nc^2+2c(suma-sumb) nc2+2c(suma−sumb) 可以用二次函数最值 O ( 1 ) O(1) O(1) 求,也就是说我们只需要求 − 2 ∑ j = 0 n − 1 a i + j b j -2sum_{j=0}^{n-1}a_{i+j}b_j −2∑j=0n−1ai+jbj 的最小值,即 ∑ j = 0 n − 1 a i + j b j sum_{j=0}^{n-1}a_{i+j}b_j ∑j=0n−1ai+jbj 的最大值。
我们可以把这个形式改变一下,把 ∑ j = 0 n − 1 a i + j b j sum_{j=0}^{n-1}a_{i+j}b_j ∑j=0n−1ai+jbj 改成 ∑ i − j = k a i b j sum_{i-j=k}a_ib_j ∑i−j=kaibj( k k k 是旋转角度)。
令 S ( k ) = ∑ i − j = k a i b j S(k)=sum_{i-j=k}a_ib_j S(k)=∑i−j=kaibj,设 a i ^ = a 2 n − i widehat{a_i}=a_{2n-i} ai =a2n−i,则:
S ( k ) = ∑ i − j = k 0 ≤ j < n a i b j = ∑ ( 2 n − i ) − j = k 0 ≤ j < n a i ^ b j = ∑ i + j = 2 n − k 0 ≤ j < n a i ^ b j S(k)=sum_{i-j=k}^{0leq j< n}a_ib_j=sum_{(2n-i)-j=k}^{0leq j< n}widehat{a_i}b_j=sum_{i+j=2n-k}^{0leq j< n}widehat{a_i}b_j S(k)=i−j=k∑0≤j<naibj=(2n−i)−j=k∑0≤j<nai bj=i+j=2n−k∑0≤j<nai bj
把 n − k n-k n−k 代入:
S ( 2 n − k ) = ∑ i + j = 2 n − ( 2 n − k ) 0 ≤ j < n a i ^ b j = ∑ i + j = k 0 ≤ j < n a i ^ b j S(2n-k)=sum_{i+j=2n-(2n-k)}^{0leq j< n}widehat{a_i}b_j=sum_{i+j=k}^{0leq j< n}widehat{a_i}b_j S(2n−k)=i+j=2n−(2n−k)∑0≤j<nai bj=i+j=k∑0≤j<nai bj
有木有觉得很熟悉?
我们可以把 a i ^ widehat{a_i} ai 看成一个多项式 A A A 的系数, b j b_j bj 看成另一个多项式 B B B 的系数,那么 S ( 2 n − k ) S(2n-k) S(2n−k) 就是 A × B A imes B A×B 第 k k k 项的系数。
看会我们原来的问题:求 ∑ i − j = k a i b j sum_{i-j=k}a_ib_j ∑i−j=kaibj 的最大值,也就是求 S ( k ) S(k) S(k) 的最大值( 0 ≤ k < n 0leq k<n 0≤k<n),也就是 A × B A imes B A×B 第 ( n + 1 ) ∼ 2 n (n+1)sim 2n (n+1)∼2n 项系数的最大值。
那么现在就好处理了,我们先用 FFT 把 A × B A imes B A×B 算出来,再取最大值。
代码如下:
#include<bits/stdc++.h>
#define N 50010
#define PN 262200
#define INF 0x7fffffff
using namespace std;
struct Complex
{
double x,y;
Complex(){};
Complex(double xx,double yy){x=xx,y=yy;}
}a[PN],b[PN],c[PN];
Complex operator + (Complex a,Complex b){return Complex(a.x+b.x,a.y+b.y);}
Complex operator - (Complex a,Complex b){return Complex(a.x-b.x,a.y-b.y);}
Complex operator * (Complex a,Complex b){return Complex(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);}
const double pi=acos(-1);
int n,m,suma,sumb,powa,powb;
int limit=1,bit,rev[PN];
void FFT(Complex *a,int opt)
{
for(int i=0;i<limit;i++)
if(i<rev[i])
swap(a[i],a[rev[i]]);
for(int mid=1;mid<limit;mid<<=1)
{
Complex wn=Complex(cos(pi/mid),opt*sin(pi/mid));
for(int i=0,len=(mid<<1);i<limit;i+=len)
{
Complex now=Complex(1,0);
for(int j=0;j<mid;j++,now=now*wn)
{
Complex x=a[i+j],y=now*a[i+mid+j];
a[i+j]=x+y,a[i+mid+j]=x-y;
}
}
}
if(opt==-1)
for(int i=0;i<limit;i++)
a[i].x/=limit;
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=0;i<n;i++)
{
scanf("%lf",&a[i].x);
a[n+i].x=a[i].x;
suma+=a[i].x;
powa+=a[i].x*a[i].x;
}
for(int i=0;i<n;i++)
{
scanf("%lf",&b[i].x);
sumb+=b[i].x;
powb+=b[i].x*b[i].x;
}
int d=round(-1.0*(suma-sumb)/n);
reverse(a,a+2*n+1);//求a^
while(limit<=3*n)
limit<<=1,bit++;
for(int i=0;i<limit;i++)
rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
FFT(a,1),FFT(b,1);
for(int i=0;i<limit;i++)
c[i]=a[i]*b[i];
FFT(c,-1);
int ans=-INF;
for(int i=n+1;i<=2*n;i++)
ans=max(ans,(int)round(c[i].x));//这里用round是怕被卡精度
printf("%d
",n*d*d+2*d*(suma-sumb)+powa+powb-2*ans);
return 0;
}