今天把最近写的一些好的题解放上来记录一下。
这道题是BZOJ4827=LG3723.
知识点:
FFT,NTT
题意:
给两个环,可以转但是不可以翻,要求给其中一个环每个元素加上一个非负整数,求每一项差的平方之和最小。
解法:
任意一个环加上一个非负整数相当于其中一个环加上一个整数。列出式子((xin Z),是我们要加上的那个数):(displaystylesum_{i=1}^{n}(a_i-b_i+x)^2=displaystylesum_{i=1}^{n}(a_i^2+b_i^2+x^2+2a_ix-2b_ix-2a_ib_i)=displaystylesum_{i=1}^{n}a_i^2+displaystylesum_{i=1}^{n}b_i^2+nx^2+2xdisplaystylesum_{i=1}^{n}(a_i-b_i)-2displaystylesum_{i=1}^{n}a_ib_i)。所以最后一项是个卷积,卷一下求最大值即可(因为倍长过a,所以找最大值从第n+1到2n项找,原式卷起来有效的就是那些地方)。中间两项是个二次函数,求一下顶点左右的两个整数即可,前两项是常数。
备注:
这种题见到环先破环成链,倍长变成序列,卷出来的答案必定在n+1到2n项(想想原来的是什么,现在的是什么)。然后就都是套路了。
代码:
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
typedef long long ll;
const int maxn=400010,mod=998244353,G=3,invG=332748118;
int n,m,a[maxn],b[maxn],c[maxn],r[maxn],ta[maxn],tb[maxn];
ll ans;
int read()
{
int x=0;
char c=getchar();
while (c<48||c>57)
c=getchar();
while (c>=48&&c<=57)
x=(x<<1)+(x<<3)+(c^48),c=getchar();
return x;
}
int qp(int a,int b)
{
int res=1;
for (;b;b>>=1)
{
if (b&1)
res=1ll*res*a%mod;
a=1ll*a*a%mod;
}
return res;
}
int init(int len)
{
int N,i,j;
for (N=1;N<len;N<<=1);
j=(N>>1);
for (i=1;i<=N-2;i++)
r[i]=(r[i>>1]>>1)|((i&1)*j);
return N;
}
void ntt(int *f,int N,bool pos)
{
int i,j,k,len,t,tmp,w;
for (i=1;i<=N-2;i++)
if (i<r[i])
swap(f[i],f[r[i]]);
for (k=2;k<=N;k<<=1)
{
len=(k>>1);
tmp=qp(pos?G:invG,(mod-1)/k);
for (j=0;j<N;j+=k)
{
w=1;
for (i=j;i<j+len;i++)
{
t=1ll*w*f[i+len]%mod;
f[i+len]=f[i]-t;
if (f[i+len]<0)
f[i+len]+=mod;
f[i]=f[i]+t;
if (f[i]>=mod)
f[i]-=mod;
w=1ll*w*tmp%mod;
}
}
}
if (!pos)
{
int invN=qp(N,mod-2);
for (i=0;i<N;i++)
f[i]=1ll*f[i]*invN%mod;
}
}
void poly_mul(int *A,int n,int *B,int m,int *C,int lim=0)
{
int i,N=init(n+m-1);
memcpy(ta,A,n*4);
fill(ta+n,ta+N,0);
memcpy(tb,B,m*4);
fill(tb+m,tb+N,0);
ntt(ta,N,1),ntt(tb,N,1);
for (i=0;i<N;i++)
C[i]=1ll*ta[i]*tb[i]%mod;
ntt(C,N,0);
memset(ta,0,N*4+4);
memset(tb,0,N*4+4);
if (lim)
fill(C+lim,C+N,0);
}
int main()
{
int n=read(),i=read(),p=0,q=0,tmp=0,x;
for (i=0;i<n;i++)
a[i]=a[i+n]=read(),ans+=a[i]*a[i];
for (i=0;i<n;i++)
b[i]=read(),ans+=b[i]*b[i],q+=a[i]-b[i];
q*=2,p=n;
reverse(b,b+n);
poly_mul(a,n*2,b,n,c,n*2);
for (i=n;i<n*2;i++)
tmp=max(tmp,c[i]);
tmp*=2;
ans-=tmp;
x=(int)floor((-1.0*q)/(2.0*p));
tmp=1ll*x*x*p+1ll*x*q;
x=(int)ceil((-1.0*q)/(2.0*p));
ans+=1ll*min(1ll*tmp,1ll*x*x*p+1ll*x*q);
printf("%lld
",ans);
return 0;
}