题目链接
(Solution)
应为我们可以将任意一个数列加上一个非负整数,即可以变为将一个数列加上一个整数(可以为负),我们将这个整数设为(z).所以要求的式子的变为:
[sum_{i=1}^{n}(x_i-y_i+z)^2
]
首先来化简一下式子
[sum_{i=1}^{n}(x_i-y_i+z)^2
]
[sum_{i=1}^{n}x_i^2+sum_{i=1}^{n}y_i^2+sum_{i=1}^{n}z_i^2+2zsum_{i=1}^{n}(x_i-y_i)-2sum_{i=1}^{n}x_i*y_i
]
我们可以发现不管如何变化(sum_{i=1}^{n}x_i^2+sum_{i=1}^{n}y_i^2)的值都不会变.
然后再看看(sum_{i=1}^{n}z_i^2+2zsum_{i=1}^{n}(x_i-y_i))显然这是一个关于(z)的二次函数.最小值可以(O(1))算出但是注意一下(z)必须是整数而且可以为负,所以需要将(-frac{x_i-y_i}{n})向上和向下取整并带入式子取最小值.
所以我们现在只需要算出(2sum_{i=1}^{n}x_i*y_i)的最大值即可.
这个如何去算,将(y)变成链,在进行一次(fft)在取一个最大值就好了
(Code)
#include<bits/stdc++.h>
#define int long long
#define rg register
#define file(x) freopen(x".in","r",stdin);freopen(x".out","w",stdout);
using namespace std;
int read(){
int x=0,f=1;
char c=getchar();
while(c<'0'||c>'9') f=(c=='-')?-1:1,c=getchar();
while(c>='0'&&c<='9') x=x*10+c-48,c=getchar();
return f*x;
}
const int N=3000001;
const double pi=3.14159265358979323846;
struct node {
double x,y;
node operator + (node z)const {
return (node){x+z.x,y+z.y};
}
node operator - (node z)const {
return (node){x-z.x,y-z.y};
}
node operator * (node z)const {
return (node){x*z.x-y*z.y,y*z.x+x*z.y};
}
}f[N],g[N];
int r[N],limit=1;
void fft(node *a,int opt){
for(int i=0;i<=limit;i++)
if(i<r[i])
swap(a[i],a[r[i]]);
for(int i=1;i<limit;i<<=1){
node w=(node){cos(pi/i),opt*sin(pi/i)};
for(int j=0;j<limit;j+=i<<1){
node l=(node){1,0};
for(int k=j;k<j+i;k++){
node p=l*a[k+i];
a[k+i]=a[k]-p;
a[k]=a[k]+p;
l=l*w;
}
}
}
}
int a[N],b[N];
main(){
int n=read(),k=read(),l=0,m=n,js=0,ans=0;
for(int i=1;i<=n;i++)
a[i]=f[i].x=read(),f[i+n].x=f[i].x,js+=a[i],ans+=a[i]*a[i];
for(int i=1;i<=n;i++)
b[i]=read(),js-=b[i],ans+=b[i]*b[i];
reverse(b+1,b+n+1);
for(int i=1;i<=n;i++)
g[i].x=b[i];
int A=-floor(js*1.0/n),B=-ceil(js*1.0/n);
ans+=min(A*A*n+2*A*js,B*B*n+2*B*js);
n*=2;
while(limit<=n+m)
limit<<=1,l++;
for(int i=0;i<limit;i++)
r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
fft(f,1),fft(g,1);
for(int i=0;i<=limit;i++)
f[i]=f[i]*g[i];
fft(f,-1);
int maxx=0;
for(int i=m;i<=n+1;i++)
maxx=max((int)(f[i].x/limit+0.5),maxx);
printf("%lld",ans-2*maxx);
}