Description
给定一个长度为(N)的数组(A[]),求有多少对(i, j, k(1leqslant i<j<k leqslant N))满足(A[k]-A[j]=A[j]-A[i])。
Solution
分块FFT。
每个暴力求需要(n)次FFT。
分块的话,FFT求块与块之间的,块内的暴力求。
复杂度(O(nsqrt {n}logn))
BZOJ上险些T了qwq...
Code
/************************************************************** Problem: 3509 User: BeiYu Language: C++ Result: Accepted Time:37220 ms Memory:7936 kb ****************************************************************/ #include <bits/stdc++.h> using namespace std; #define debug(a) cout<<#a<<"="<<a<<" " #define mpr make_pair #define r first #define i second typedef pair< double,double > pr; typedef long long LL; const int N = 1e5+50; const int B = 2500; const double Pi = M_PI; pr operator + (const pr &a,const pr &b) { return mpr(a.r+b.r,a.i+b.i); } pr operator - (const pr &a,const pr &b) { return mpr(a.r-b.r,a.i-b.i); } pr operator * (const pr &a,const pr &b) { return mpr(a.r*b.r-a.i*b.i,a.r*b.i+a.i*b.r); } int NN=65536; void Rev(pr a[]) { for(int i=0,j=0;i<NN;i++) { if(i<j) swap(a[i],a[j]); for(int k=NN>>1;(j^=k)<k;k>>=1); } } void DFT(pr a[],int r=1) { Rev(a); for(int i=1;i<=NN;i<<=1) { pr wi=mpr(cos(2.0*Pi/i),r*sin(2.0*Pi/i)); for(int j=0;j<NN;j+=i) { pr w=mpr(1.0,0.0); for(int k=j;k<j+i/2;k++) { pr x=a[k],y=w*a[k+i/2]; a[k]=x+y,a[k+i/2]=x-y; w=w*wi; } } }if(r==-1) for(int i=0;i<NN;i++) a[i].r/=NN; } void FFT(pr a[],pr b[],pr c[]) { DFT(a,1),DFT(b,1); for(int i=0;i<NN;i++) c[i]=a[i]*b[i]; DFT(c,-1); } inline int in(int x=0,char ch=getchar()) { while(ch>'9' || ch<'0') ch=getchar(); while(ch>='0' && ch<='9') x=x*10+ch-'0',ch=getchar();return x; } int n;LL ans; int a[N],b[N]; int bf[N],bd[N],tp[N]; pr x1[N],x2[N],x3[N]; int main() { n=in(); for(int i=0;i<n;i++) a[i]=in(); for(int i=0;i<n;i++) bd[a[i]]++; for(int j=0;j<n;j+=B) { for(int i=j;i<n && i<j+B;i++) bd[a[i]]--; //before and behind memset(x1,0,sizeof(x1)),memset(x2,0,sizeof(x2)); for(int i=0;i<NN/2;i++) x1[i]=mpr(bf[i],0),x2[i]=mpr(bd[i],0); FFT(x1,x2,x3); for(int i=j;i<n && i<j+B;i++) tp[a[i]]++; for(int i=0;i<NN;i++) if(!(i&1)) ans+=(LL)tp[i/2]*(int)(x3[i].r+0.5); for(int i=j;i<n && i<j+B;i++) tp[a[i]]=0; for(int p=j;p<n && p<j+B;p++) for(int q=p+1;q<n && q<j+B;q++) { if(2*a[p]-a[q]>=0) ans+=bf[2*a[p]-a[q]]; if(2*a[q]-a[p]>=0) ans+=bd[2*a[q]-a[p]]; } for(int p=j;p<n && p<j+B;p++) { for(int q=j;q<p;q++) { if(2*a[q]-a[p]>=0) ans+=tp[2*a[q]-a[p]]; tp[a[q]]++; } for(int q=j;q<p;q++) tp[a[q]]--; } for(int i=j;i<n && i<j+B;i++) bf[a[i]]++; }cout<<ans<<endl; return 0; }