来自FallDream的博客,未经允许,请勿转载,谢谢。
题意:给你n个数ai,求有多少个数对(i,j,k)满足$1leqslant i<j<kleqslant n$且aj*2=ai+ak n<=100000 ai<=30000
题解:考虑构造一个生成函数,只要把左右的生成函数乘起来,然后枚举i就行了。但是每次平方都需要$nlogn$的时间,总复杂度$n^{2}logn$,不能过。
考虑分块,块外的(即满足$1leqslant i<Lleqslant jleqslant R<kleqslant n$的)用生成函数处理一下,块内的直接暴力算.
块的大小是k的时候,复杂度$k^{2}*frac{n}{k}+frac{n}{k}nlogn$,得出k大概等于$sqrt{nlogn}$时候最小,复杂度$nsqrt{nlogn}$
#include<iostream> #include<cstdio> #include<cstring> #include<cmath> #define getchar() (*S++) #define MN 100000 #define pi acos(-1) #define ll long long char B[1<<26],*S=B; using namespace std; int X;char ch; inline int read() { X = 0 , ch = getchar(); while(ch < '0' || ch > '9') ch = getchar(); while(ch >= '0' && ch <= '9'){X = X * 10 + ch - '0';ch = getchar();} return X; } struct cp{ double r,u; cp(double _r=0,double _u=0):r(_r),u(_u){} cp operator+(cp b){return cp(r+b.r,u+b.u);} cp operator-(cp b){return cp(r-b.r,u-b.u);} cp operator*(cp b){return cp(r*b.r-u*b.u,r*b.u+u*b.r);} cp operator/(double y){return cp(r/y,u/y);} }w[2][MN],a[MN],b[MN]; int n,s[MN+5],size,N;ll num1[MN],num2[MN]; ll ans=0; void init(int mx) { w[0][0]=w[1][mx]=cp(1,0); w[0][1]=w[1][mx-1]=cp(cos(2*pi/mx),sin(2*pi/mx)); for(int i=2;i<=mx;i++) w[0][i]=w[1][mx-i]=w[0][i-1]*w[0][1]; } void fft(cp*x,int b) { for(register int i=0,j=0;i<N;++i) { if(i>j)swap(x[i],x[j]); for(int k=N>>1;(j^=k)<k;k>>=1); } for(register int i=2;i<=N;i<<=1)for(register int j=0;j<N;j+=i)for(register int k=0;k<i>>1;k++) { cp t=x[j+k+(i>>1)]*w[b][N/i*k]; x[j+k+(i>>1)]=x[j+k]-t; x[j+k]=x[j+k]+t; } if(b)for(register int i=0;i<N;i++)x[i]=x[i]/N; } int main() { fread(B,1,1<<26,stdin); n=read();size=min(n,6*(int)sqrt(n)); for(register int i=1;i<=n;++i)s[i]=read(),++num2[s[i]]; for(register int i=1;i<n;i+=size) { int r=min(i+size-1,n); for(register int j=i;j<=r;++j)--num2[s[j]]; for(register int j=i;j<=r;++j) { for(int k=j+1;k<=r;k++) { int x=2*s[j]-s[k]; if(x>=0)ans+=num1[x]; x=2*s[k]-s[j]; if(x>=0)ans+=num2[x]; } ++num1[s[j]]; } } for(register int i=1;i<n;i+=size) { int r=min(i+size-1,n),mx=0; for(register int j=1;j<i;++j)a[s[j]].r++,mx=max(mx,s[j]); for(register int j=r+1;j<=n;++j)b[s[j]].r++,mx=max(mx,s[j]); for(N=1;N<=mx;N<<=1);N<<=1; init(N);fft(a,0);fft(b,0); for(register int j=0;j<N;++j)a[j]=a[j]*b[j]; fft(a,1); for(register int j=i;j<=r;++j) ans+=(ll)(a[s[j]<<1].r+0.5); memset(a,0,sizeof(cp)*N); memset(b,0,sizeof(cp)*N); } printf("%lld ",ans); return 0; }