设序列a的生成函数$large f(x)=sumlimits_{i=0}^{n-1}a_ix^i$,则操作1,2,3分别对应将$f(x)$乘上$Largefrac{1}{1-x},frac{1}{1-x^2},frac{1}{1-x^3}$,如果操作1,2,3分别进行了p1,p2,p3次,则最终序列的生成函数为$Largefrac{f(x)}{(1-x)^{p_1}(1-x^2)^{p_2}(1-x^3)^{p_3}}$,套个二项式定理+多项式乘法+多项式逆元即可。由于题目中的模数刚好可以NTT,因此直接NTT即可。(ps:浮点数FFT取模常数太大,会TLE)
1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef long long ll; 4 const int N=4e5+10,M=1e6+10,mod=998244353; 5 const int G=3; 6 int n,m,n2,a[N],b[3][N],cnt[3],fac[M],inv[M],invf[M]; 7 int Pow(int x,int p) { 8 int ret=1; 9 for(; p; p>>=1,x=(ll)x*x%mod)if(p&1)ret=(ll)ret*x%mod; 10 return ret; 11 } 12 int C(int n,int m) {return n<m?0:(ll)fac[n]*invf[m]%mod*invf[n-m]%mod;} 13 struct F_FT { 14 int A[N],B[N],b[N],c[N]; 15 void FFT(int* a,int n,int f) { 16 for(int i=1,j=n>>1,k; i<n-1; ++i,j^=k) { 17 if(i<j)swap(a[i],a[j]); 18 for(k=n>>1; j&k; j^=k,k>>=1); 19 } 20 for(int k=1; k<n; k<<=1) { 21 int gn=Pow(G,(mod-1)/(k<<1)); 22 if(f==-1)gn=Pow(gn,mod-2); 23 for(int i=0; i<n; i+=k<<1) { 24 int g=1; 25 for(int j=i; j<i+k; ++j,g=(ll)g*gn%mod) { 26 int x=a[j],y=(ll)g*a[j+k]%mod; 27 a[j]=((ll)x+y)%mod,a[j+k]=((ll)x-y+mod)%mod; 28 } 29 } 30 } 31 if(!~f)for(int i=0; i<n; ++i)a[i]=(ll)a[i]*inv[n]%mod; 32 } 33 void mul(int* a,int* b,int* c,int n) { 34 for(int i=0; i<n; ++i)A[i]=a[i],B[i]=b[i],A[i+n]=B[i+n]=0; 35 n<<=1; 36 FFT(A,n,1),FFT(B,n,1); 37 for(int i=0; i<n; ++i)c[i]=(ll)A[i]*B[i]%mod; 38 FFT(c,n,-1); 39 } 40 void inverse(int* a,int n) { 41 for(int i=0; i<n; ++i)b[i]=0; 42 b[0]=Pow(a[0],mod-2); 43 for(int m=2; m<=n; m<<=1) { 44 mul(b,b,c,m),mul(a,c,c,m); 45 for(int i=0; i<m; ++i)b[i]=(((ll)b[i]*2-c[i])%mod+mod)%mod; 46 } 47 for(int i=0; i<n; ++i)a[i]=b[i]; 48 } 49 } fft; 50 int main() { 51 fac[0]=invf[0]=inv[1]=1; 52 for(int i=2; i<M; ++i)inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod; 53 for(int i=1; i<M; ++i)fac[i]=(ll)fac[i-1]*i%mod,invf[i]=(ll)invf[i-1]*inv[i]%mod; 54 int T; 55 for(scanf("%d",&T); T--;) { 56 memset(cnt,0,sizeof cnt); 57 memset(a,0,sizeof a); 58 scanf("%d%d",&n,&m); 59 n2=1; 60 for(; n2<n; n2<<=1); 61 for(int i=0; i<n; ++i)scanf("%d",&a[i]); 62 while(m--) { 63 int x; 64 scanf("%d",&x); 65 cnt[x-1]++; 66 } 67 for(int j=0; j<3; ++j) { 68 for(int i=0; i<n2; ++i)b[j][i]=0; 69 for(int i=0; i*(j+1)<n2; ++i)b[j][i*(j+1)]=(ll)C(cnt[j],i)*(i&1?mod-1:1)%mod; 70 if(j)fft.mul(b[0],b[j],b[0],n2); 71 } 72 fft.inverse(b[0],n2),fft.mul(a,b[0],a,n2); 73 ll ans=0; 74 for(int i=0; i<n; ++i)ans^=(ll)a[i]*(i+1); 75 printf("%lld ",ans); 76 } 77 return 0; 78 }
也可以直接利用性质$Largefrac{1}{(1-x)^n}=sumlimits_{i=0}^{n}C_{n-1+i}^{i}x^i$,省去了求逆元的过程。
1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef long long ll; 4 const int N=4e5+10,M=1e6+10,mod=998244353; 5 const int G=3; 6 int n,m,n2,a[N],b[3][N],c[N],cnt[3],fac[M],inv[M],invf[M]; 7 int Pow(int x,int p) { 8 int ret=1; 9 for(; p; p>>=1,x=(ll)x*x%mod)if(p&1)ret=(ll)ret*x%mod; 10 return ret; 11 } 12 int C(int n,int m) {return n<m?0:(ll)fac[n]*invf[m]%mod*invf[n-m]%mod;} 13 struct F_FT { 14 int A[N],B[N],c[N]; 15 void FFT(int* a,int n,int f) { 16 for(int i=1,j=n>>1,k; i<n-1; ++i,j^=k) { 17 if(i<j)swap(a[i],a[j]); 18 for(k=n>>1; j&k; j^=k,k>>=1); 19 } 20 for(int k=1; k<n; k<<=1) { 21 int gn=Pow(G,(mod-1)/(k<<1)); 22 if(f==-1)gn=Pow(gn,mod-2); 23 for(int i=0; i<n; i+=k<<1) { 24 int g=1; 25 for(int j=i; j<i+k; ++j,g=(ll)g*gn%mod) { 26 int x=a[j],y=(ll)g*a[j+k]%mod; 27 a[j]=((ll)x+y)%mod,a[j+k]=((ll)x-y+mod)%mod; 28 } 29 } 30 } 31 if(!~f)for(int i=0; i<n; ++i)a[i]=(ll)a[i]*inv[n]%mod; 32 } 33 void mul(int* a,int* b,int* c,int n) { 34 for(int i=0; i<n; ++i)A[i]=a[i],B[i]=b[i],A[i+n]=B[i+n]=0; 35 n<<=1; 36 FFT(A,n,1),FFT(B,n,1); 37 for(int i=0; i<n; ++i)c[i]=(ll)A[i]*B[i]%mod; 38 FFT(c,n,-1); 39 } 40 } fft; 41 int main() { 42 fac[0]=invf[0]=inv[1]=1; 43 for(int i=2; i<M; ++i)inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod; 44 for(int i=1; i<M; ++i)fac[i]=(ll)fac[i-1]*i%mod,invf[i]=(ll)invf[i-1]*inv[i]%mod; 45 int T; 46 for(scanf("%d",&T); T--;) { 47 memset(cnt,0,sizeof cnt); 48 memset(a,0,sizeof a); 49 scanf("%d%d",&n,&m); 50 n2=1; 51 for(; n2<n; n2<<=1); 52 for(int i=0; i<n; ++i)scanf("%d",&a[i]); 53 while(m--) { 54 int x; 55 scanf("%d",&x); 56 cnt[x-1]++; 57 } 58 for(int j=0; j<3; ++j) { 59 for(int i=0; i<n2; ++i)b[j][i]=0; 60 if(cnt[j]==0)b[j][0]=1; 61 else for(int i=0; i*(j+1)<n2; ++i)b[j][i*(j+1)]=C(cnt[j]-1+i,i); 62 if(j)fft.mul(b[0],b[j],b[0],n2); 63 } 64 fft.mul(a,b[0],a,n2); 65 ll ans=0; 66 for(int i=0; i<n; ++i)ans^=(ll)a[i]*(i+1); 67 printf("%lld ",ans); 68 } 69 return 0; 70 }