题目:https://www.luogu.org/problemnew/show/P4721
分治做法,考虑左边对右边的贡献即可;
注意最大用到的 a 的项也不过是 a[r-l] ,所以 NTT 可以只做到 2*(r-l),能快一倍。
代码如下:
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> using namespace std; typedef long long ll; int const xn=(1<<18),mod=998244353; int n,f[xn],g[xn],a[xn],b[xn],rev[xn]; int rd() { int ret=0,f=1; char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=0; ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return f?ret:-ret; } ll pw(ll a,int b) { ll ret=1; for(;b;b>>=1,a=(a*a)%mod)if(b&1)ret=(ret*a)%mod; return ret; } int upt(int x){while(x>=mod)x-=mod; while(x<0)x+=mod; return x;} void ntt(int *a,int tp,int lim) { for(int i=0;i<lim;i++) if(i<rev[i])swap(a[i],a[rev[i]]); for(int mid=1;mid<lim;mid<<=1) { int wn=pw(3,tp==1?(mod-1)/(mid<<1):(mod-1)-(mod-1)/(mid<<1)); for(int j=0,len=(mid<<1);j<lim;j+=len) for(int k=0,w=1;k<mid;k++,w=(ll)w*wn%mod) { int x=a[j+k],y=(ll)w*a[j+mid+k]%mod; a[j+k]=upt(x+y); a[j+mid+k]=upt(x-y); } } if(tp==1)return; int inv=pw(lim,mod-2); for(int i=0;i<lim;i++)a[i]=(ll)a[i]*inv%mod; } void work(int l,int r) { if(l==r)return; int len=r-l+1,mid=((l+r)>>1); work(l,mid); int lim=1,L=0; while(lim<=(r-l))lim<<=1,L++;//max:r-l for(int i=0;i<lim;i++)rev[i]=((rev[i>>1]>>1)|((i&1)<<(L-1))); for(int i=0;i<lim;i++)a[i]=b[i]=0;// for(int i=l;i<=mid;i++)a[i-l]=f[i]; for(int i=0;i<len;i++)b[i]=g[i]; ntt(a,1,lim); ntt(b,1,lim); for(int i=0;i<lim;i++)a[i]=(ll)a[i]*b[i]%mod; ntt(a,-1,lim); for(int i=mid+1;i<=r;i++)f[i]=upt(f[i]+a[i-l]); work(mid+1,r); } int main() { n=rd(); f[0]=1; for(int i=1;i<n;i++)g[i]=rd(); work(0,n-1); for(int i=0;i<n;i++)printf("%d ",f[i]); puts(""); return 0; }
多项式求逆做法感觉很妙:
设 ( F(x) = sum f_{i}*x_{i} ),( G(x) = sum g_{i}*x_{i} )
则 ( F(x) * G(x) = sum x_{i} * sumlimits_{j=0}^{i} f_{j}*g_{i-j} )
即 ( F(x) * G(x) = F(x) - f_{0}*x_{0} )
所以 ( F(x) = (1-G(x))^{-1} )
多项式求逆即可。
代码如下:
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> using namespace std; typedef long long ll; int const xn=(1<<18),mod=998244353; int n,f[xn],g[xn],c[xn],rev[xn]; int rd() { int ret=0,f=1; char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=0; ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return f?ret:-ret; } ll pw(ll a,int b) { ll ret=1; for(;b;b>>=1,a=(a*a)%mod)if(b&1)ret=(ret*a)%mod; return ret; } int upt(int x){while(x>=mod)x-=mod; while(x<0)x+=mod; return x;} void ntt(int *a,int tp,int lim) { for(int i=0;i<lim;i++) if(i<rev[i])swap(a[i],a[rev[i]]); for(int mid=1;mid<lim;mid<<=1) { int p=mod-1,len=(mid<<1),wn=pw(3,tp==1?p/len:p-p/len); for(int j=0;j<lim;j+=len) for(int k=0,w=1;k<mid;k++,w=(ll)w*wn%mod) { int x=a[j+k],y=(ll)w*a[j+mid+k]%mod; a[j+k]=upt(x+y); a[j+mid+k]=upt(x-y); } } if(tp==1)return; int inv=pw(lim,mod-2); for(int i=0;i<lim;i++)a[i]=(ll)a[i]*inv%mod; } void inv(int *a,int *b,int n) { if(n==1){b[0]=pw(a[0],mod-2); return;} inv(a,b,(n+1)>>1); int lim=1,l=0; while(lim<n+n)lim<<=1,l++; for(int i=0;i<lim;i++)rev[i]=((rev[i>>1]>>1)|((i&1)<<(l-1))); for(int i=0;i<n;i++)c[i]=a[i]; for(int i=n;i<lim;i++)c[i]=0; ntt(c,1,lim); ntt(b,1,lim); for(int i=0;i<lim;i++)b[i]=((ll)2-(ll)c[i]*b[i])%mod*b[i]%mod; ntt(b,-1,lim); for(int i=n;i<lim;i++)b[i]=0; } int main() { n=rd(); f[0]=1; g[0]=1; for(int i=1;i<n;i++)g[i]=-rd(); inv(g,f,n); for(int i=0;i<n;i++)printf("%d ",f[i]); puts(""); return 0; }