BZOJ
Luogu
题意:
给定(n),(m),(x_i),(y_i),求(sum_{i=1}^{n}(x_{i+k}-y_i+c)^2)的最小值(其中(kin[0,n)),(cin[-m,m]))
上面的表述是默认(x_{i+n}=x_i),因为这是一个环呀
sol
那个。。。关于(cin[-m,m])应该没什么问题吧。就是说(c)显然不会无限增大,当(c>m)或(c<-m)时(c)再增大或者减小就没有任何意义了(因为其中一个手环的亮度值已经完全大于另一个)
首先我们大力拆式子
原式=
[sum_{i=1}^{n}[x_{i+k}^2+y_i^2+c^2-2*x_{i+k}y_i+2*(x_{i+k}-y_i)*c]
]
把外面的(sum)拆了
[=sum_{i=1}^{n}x_i^2+sum_{i=1}^{n}y_i^2+n*c^2-2sum_{i=1}^{n}x_{i+k}y_i+2(sum_{i=1}^{n}x_i-sum_{i=1}^{n}y_i)*c
]
发现(k)跟(c)没什么关系,所以我们把关于(k)和(c)的分别提出来就可以了。
我们从容易的入手。
与(c)和(k)都没关系的项:
[sum_{i=1}^{n}x_i^2+sum_{i=1}^{n}y_i^2
]
直接(O(n))搞出来不解释。
与(c)有关的项:
[n*c^2+2(sum_{i=1}^{n}x_i-sum_{i=1}^{n}y_i)*c
]
那个系数先(O(n))搞出来,再直接枚举(c)做(O(m))的计算即可。
与(k)有关的项:
[sum_{i=1}^{n}x_{i+k}y_i
]
我们需要最大化这个东西(因为前面是减号嘛)
我们往FFT上面靠
我们把(y)反过来,上式变成
[sum_{i=1}^{n}x_{i+k}y_{n-i+1}
]
发现这个是(x)与(y)的卷积的第(n+k+1)次项系数!
所以直接求(x)和(y)的卷积然后在系数里面取个(max)即可。注意(x)要倍长,(y)要反向。
最后把所有的东西加在一起即可。
code
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<complex>
using namespace std;
#define ll long long
const int MAX = 300005;
const double Pi = acos(-1);
int gi()
{
int x=0,w=1;char ch=getchar();
while ((ch<'0'||ch>'9')&&ch!='-') ch=getchar();
if (ch=='-') w=0,ch=getchar();
while (ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
return w?x:-x;
}
int N,M,n,m,x[MAX],y[MAX],r[MAX],l,k;
complex<double>a[MAX],b[MAX];
ll tot,ans;
void FFT(complex<double>*P,int opt)
{
for (int i=1;i<n;i++)
if (i<r[i]) swap(P[i],P[r[i]]);
for (int i=1;i<n;i<<=1)
{
complex<double>W(cos(Pi/i),opt*sin(Pi/i));
for (int p=i<<1,j=0;j<n;j+=p)
{
complex<double>w(1,0);
for (int k=0;k<i;k++,w*=W)
{
complex<double>X=P[j+k],Y=w*P[j+k+i];
P[j+k]=X+Y;P[j+k+i]=X-Y;
}
}
}
}
int main()
{
N=gi();M=gi();
for (int i=1;i<=N;i++) x[i]=gi();
for (int i=1;i<=N;i++) y[i]=gi();
for (int i=1;i<=N;i++)
tot+=x[i]*x[i]+y[i]*y[i],k+=x[i]-y[i];
for (int i=1;i<=N;i++)
a[i]=a[i+N]=x[i],b[i]=y[N-i+1];
m=3*N;
for (n=1;n<=m;n<<=1) l++;l--;
for (int i=0;i<n;i++) r[i]=(r[i>>1]>>1)|((i&1)<<l);
FFT(a,1);FFT(b,1);
for (int i=0;i<n;i++) a[i]=a[i]*b[i];
FFT(a,-1);
for (int i=0;i<N;i++) ans=max(ans,(ll)(a[N+i+1].real()/n+0.5));
tot-=2*ans;ans=1e18;
for (int c=-M;c<=M;c++)
ans=min(ans,tot+N*c*c+2*k*c);
printf("%lld
",ans);
return 0;
}