http://uoj.ac/problem/34
FFT真是一个丧心病狂的东西
递归版
// Polynomial multiplication with a recursive radix-2 FFT.
// Input : n m, then the n+1 coefficients of A and the m+1 coefficients of B
//         (lowest degree first).
// Output: the n+m+1 coefficients of A*B, space separated.
#include <cstdio>
#include <cmath>
#include <vector>
#define FOR(i, s, t) for (int i = (s); i <= (t); ++i)
typedef double db;
const db pi = std::acos(-1);
const int N = 500011;

// Minimal complex number providing the three operations the FFT needs.
struct complex {
    db r, i;
    typedef complex cp;
    inline cp operator+(cp A) const { return cp{r + A.r, i + A.i}; }
    inline cp operator-(cp A) const { return cp{r - A.r, i - A.i}; }
    inline cp operator*(cp A) const { return cp{r * A.r - i * A.i, r * A.i + i * A.r}; }
} a[N], b[N];
typedef complex cp;

// In-place recursive FFT of x[0..n), n a power of two.
// type = 1 evaluates at the roots e^{2*pi*i*k/n}; type = -1 uses the conjugate
// roots, so forward followed by inverse multiplies every entry by n.
void fft(cp *x, int n, int type) {
    if (n == 1) return;
    const int hf = n >> 1;
    // Heap scratch: the former stack VLAs were non-standard C++ and their
    // combined stack use across the recursion grew with n.
    std::vector<cp> l(hf), r(hf);
    for (int k = 0; k < hf; ++k) {
        l[k] = x[2 * k];      // even-indexed coefficients
        r[k] = x[2 * k + 1];  // odd-indexed coefficients
    }
    fft(l.data(), hf, type);
    fft(r.data(), hf, type);
    const cp wn = cp{std::cos(2 * pi / n), std::sin(type * 2 * pi / n)};
    cp w = cp{1, 0};
    for (int k = 0; k < hf; ++k, w = w * wn) {
        const cp t = w * r[k];
        x[k] = l[k] + t;
        x[k + hf] = l[k] - t;
    }
}

int n, m, x;
int main() {
    scanf("%d%d", &n, &m);
    FOR(i, 0, n) scanf("%d", &x), a[i].r = x;
    FOR(i, 0, m) scanf("%d", &x), b[i].r = x;
    m += n;                  // m is now the degree of the product
    n = 1;
    while (n <= m) n <<= 1;  // transform length: smallest power of two > m
    fft(a, n, 1);
    fft(b, n, 1);
    // Pointwise product over the n transformed slots only (indices 0..n-1).
    FOR(i, 0, n - 1) a[i] = a[i] * b[i];
    fft(a, n, -1);
    // Divide by n to undo the scaling; lround also rounds negative values correctly.
    FOR(i, 0, m) printf("%d ", static_cast<int>(std::lround(a[i].r / n)));
    return 0;
}
迭代版
// Polynomial multiplication with an iterative (bit-reversal) FFT and hand-rolled I/O.
// Input : n m, then the n+1 coefficients of A and the m+1 coefficients of B
//         (lowest degree first).
// Output: the n+m+1 coefficients of A*B, each followed by a space.
#include <cstdio>
#include <cmath>
#include <algorithm>
#define FOR(i, s, t) for (int i = (s); i <= (t); ++i)
using std::swap;
typedef double db;
const db pi = std::acos(-1);

// Minimal complex number providing the three operations the FFT needs.
struct complex {
    db r, i;
    typedef complex cp;
    inline cp operator+(cp A) const { return cp{r + A.r, i + A.i}; }
    inline cp operator-(cp A) const { return cp{r - A.r, i - A.i}; }
    inline cp operator*(cp A) const { return cp{r * A.r - i * A.i, r * A.i + A.r * i}; }
} a[1 << 18], b[1 << 18], wn[18];
typedef complex cp;
int p[1 << 18];  // p[i]: index i with its low log2(n) bits reversed
int n, m;

// In-place iterative FFT of a[0..n) (n = global transform length, a power of two).
// Uses the bit-reversal table p and the per-stage root wn[t] for butterflies of
// half-span 2^t; the direction is chosen by how main() fills wn.
inline void fft(cp *a) {
    FOR(i, 0, n - 1) if (i < p[i]) swap(a[i], a[p[i]]);
    for (int i = 1, t = 0; i < n; i <<= 1, ++t) {
        const int len = i << 1;
        const cp w = wn[t];
        for (int j = 0; j < n; j += len) {
            cp v = cp{1, 0};
            const int e = i + j;
            for (int k = j; k < e; ++k, v = v * w) {
                const cp y = v * a[k + i];
                a[k + i] = a[k] - y;
                a[k] = a[k] + y;
            }
        }
    }
}

// Reads the next decimal integer from stdin, skipping any separators
// (spaces, '\n', '\r', ...). A '-' directly before the digits negates it.
// Returns 0 at EOF.
inline int read() {
    int c = getchar();  // int, not char, so EOF is distinguishable
    bool neg = false;
    while (c != EOF && (c < '0' || c > '9')) {
        neg = (c == '-');
        c = getchar();
    }
    int data = 0;
    while (c >= '0' && c <= '9') {
        data = data * 10 + (c - '0');
        c = getchar();
    }
    return neg ? -data : data;
}

int wr[51];  // digit stack for write(); wr[0] is the stack size
// Writes x in decimal to stdout with no trailing separator.
inline void write(int x) {
    if (x < 0) {
        putchar('-');
        x = -x;
    }
    if (!x) {
        putchar('0');
        return;
    }
    while (x) wr[++wr[0]] = x % 10, x /= 10;
    while (wr[0]) putchar('0' + wr[wr[0]--]);
}

int main() {
    n = read();
    m = read();
    FOR(i, 0, n) a[i].r = read();
    FOR(i, 0, m) b[i].r = read();
    m += n;                  // m is now the degree of the product
    n = 1;
    while (n <= m) n <<= 1;  // transform length: smallest power of two > m
    // (n >> 1) is the top bit of the index range. Using it instead of
    // 1 << (log2(n) - 1) avoids an undefined shift by -1 when n == 1.
    FOR(i, 0, n - 1) p[i] = (p[i >> 1] >> 1) | ((i & 1) ? (n >> 1) : 0);
    for (int i = 1, t = 0; i < n; i <<= 1, ++t) wn[t] = cp{std::cos(pi / i), std::sin(pi / i)};
    fft(a);
    fft(b);
    FOR(i, 0, n - 1) a[i] = a[i] * b[i];
    // Inverse transform: same butterflies with conjugated roots, then divide by n.
    for (int i = 1, t = 0; i < n; i <<= 1, ++t) wn[t] = cp{std::cos(pi / i), std::sin(-pi / i)};
    fft(a);
    FOR(i, 0, m) write(static_cast<int>(std::lround(a[i].r / n))), putchar(' ');
    return 0;
}
NTT
// Polynomial multiplication with a number-theoretic transform (NTT) modulo the
// prime 479 * 2^21 + 1. Coefficients of the product are exact as long as they
// are non-negative and below that modulus.
// Input : n m, then the n+1 coefficients of A and the m+1 coefficients of B
//         (lowest degree first).
// Output: the n+m+1 coefficients of A*B, space separated.
#include <cstdio>
#include <algorithm>
using namespace std;
const int mod = 479 << 21 | 1, maxn = 1e6;
int a[maxn], b[maxn], p[maxn], s[maxn], gn[maxn];
int n, m, g, ny;

// Modular exponentiation: a^b mod `mod` for b >= 0.
inline int fp(int a, int b) {
    int ret = 1;
    for (; b; b >>= 1, a = 1ll * a * a % mod)
        if (b & 1) ret = 1ll * a * ret % mod;
    return ret;
}

// Smallest primitive root of the prime p: the least i with i^((p-1)/q) != 1
// for every prime factor q of p-1. The distinct prime factors of p-1 are
// stored in s[1..s[0]].
inline int get_g(int p) {
    int x = p - 1;
    for (int i = 2; i * i <= x; ++i)
        if (x % i == 0) {
            while (x % i == 0) x /= i;
            s[++s[0]] = i;
        }
    if (x > 1) s[++s[0]] = x;
    for (int i = 2;; ++i) {
        bool primitive = true;
        for (int j = 1; j <= s[0] && primitive; ++j)
            if (fp(i, (p - 1) / s[j]) == 1) primitive = false;
        if (primitive) return i;
    }
}

// In-place forward NTT of a[0..m) (m = global transform length, a power of two),
// using the bit-reversal table p and the per-stage roots gn[t].
// main() obtains the inverse by running this again, reversing a[1..m) and
// scaling by m^{-1}.
inline void ntt(int *a) {
    for (int i = 0; i < m; ++i)
        if (i < p[i]) swap(a[i], a[p[i]]);
    for (int i = 1, t = 0; i < m; i <<= 1, ++t) {
        const int len = i << 1;
        for (int j = 0; j < m; j += len) {
            int w = 1;
            for (int k = j; k < i + j; ++k, w = 1ll * w * gn[t] % mod) {
                const int v = 1ll * w * a[i + k] % mod;
                a[i + k] = (a[k] - v + mod) % mod;
                a[k] = (a[k] + v) % mod;
            }
        }
    }
}

int main() {
    g = get_g(mod);
    scanf("%d%d", &n, &m);
    for (int i = 0; i <= n; ++i) scanf("%d", a + i);
    for (int i = 0; i <= m; ++i) scanf("%d", b + i);
    n += m;                  // n is now the degree of the product
    m = 1;
    while (m <= n) m <<= 1;  // transform length: smallest power of two > n
    // (m >> 1) is the top bit of the index range. Using it instead of
    // 1 << (log2(m) - 1) avoids an undefined shift by -1 when m == 1.
    for (int i = 0; i < m; ++i) p[i] = (p[i >> 1] >> 1) | ((i & 1) ? (m >> 1) : 0);
    // gn[t]: root of unity of order 2^(t+1), derived from the primitive root g.
    for (int i = 1, t = 0; i < m; i <<= 1, ++t) gn[t] = fp(g, (mod - 1) / (i << 1));
    ntt(a);
    ntt(b);
    for (int i = 0; i < m; ++i) a[i] = 1ll * a[i] * b[i] % mod;
    ntt(a);
    reverse(a + 1, a + m);  // forward transform + reversal of a[1..m) = inverse up to scale
    ny = fp(m, mod - 2);    // m^{-1} by Fermat's little theorem
    for (int i = 0; i < m; ++i) a[i] = 1ll * a[i] * ny % mod;
    for (int i = 0; i <= n; ++i) printf("%d ", a[i]);
    return 0;
}
多项式求逆元
// Polynomial inverse modulo x^(n+1) over Z/998244353 via Newton iteration:
// B_new = 2*B - A*B^2 (mod x^deg), where B is the inverse modulo x^ceil(deg/2).
// Input : n, then the n+1 coefficients of A (lowest degree first).
// Output: the n+1 coefficients of A^{-1} mod x^(n+1).
#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;
// Every scratch array is indexed up to the transform length p, which is the
// smallest power of two >= 2 * (n + 1). maxn = 1 << 19 covers n + 1 <= 262144.
// The former 2e5 + 5 overflowed as soon as n + 1 > 65536 (p = 262144).
const int mod = 998244353, maxn = 1 << 19;
int a[maxn], b[maxn], tmp[maxn], s[maxn], gn[maxn];
int n;

// Modular exponentiation: a^b mod `mod` for b >= 0.
inline int fp(int a, int b) {
    int ret = 1;
    for (; b; b >>= 1, a = 1ll * a * a % mod)
        if (b & 1) ret = 1ll * a * ret % mod;
    return ret;
}

// In-place NTT of a[0..p), p a power of two, using the bit-reversal table s and
// the per-stage roots gn[t].
// f == 1: forward transform.
// Otherwise: inverse transform (forward butterflies, then reverse a[1..p) and
// scale by p^{-1}).
inline void ntt(int *a, int p, int f) {
    for (int i = 0; i < p; ++i)
        if (i < s[i]) swap(a[i], a[s[i]]);
    for (int i = 1, t = 0; i < p; i <<= 1, ++t) {
        const int g = gn[t];
        for (int j = 0; j < p; j += (i << 1)) {
            int w = 1;
            for (int k = j; k < i + j; ++k, w = 1ll * w * g % mod) {
                const int v = 1ll * w * a[i + k] % mod;
                a[i + k] = (a[k] - v + mod) % mod;
                a[k] = (a[k] + v) % mod;
            }
        }
    }
    if (f == 1) return;
    reverse(a + 1, a + p);
    const int ny = fp(p, mod - 2);
    for (int i = 0; i < p; ++i) a[i] = 1ll * a[i] * ny % mod;
}

// Writes into b[0..deg) the inverse of the global a[0..deg) modulo x^deg.
// a[0] must be nonzero mod 998244353 for the inverse to exist.
// Entries of b at index >= deg are left as scratch.
inline void solve(int *b, int deg) {
    if (deg == 1) {
        b[0] = fp(a[0], mod - 2);
        return;
    }
    solve(b, (deg + 1) >> 1);
    int p = 1, lg2 = 0;  // deg >= 2 here, so p >= 4 and lg2 >= 2
    while (p < (deg << 1)) p <<= 1, ++lg2;
    for (int i = 0; i < p; ++i) tmp[i] = i < deg ? a[i] : 0;
    // Keep only the half-size inverse; clear scratch left by the deeper level.
    for (int i = (deg + 1) >> 1; i < p; ++i) b[i] = 0;
    for (int i = 0; i < p; ++i) s[i] = (s[i >> 1] >> 1) ^ ((i & 1) << (lg2 - 1));
    ntt(tmp, p, 1);
    ntt(b, p, 1);
    for (int i = 0; i < p; ++i)
        b[i] = (2ll * b[i] % mod - 1ll * tmp[i] * b[i] % mod * b[i] % mod + mod) % mod;
    ntt(b, p, -1);
}

int main() {
    // gn[t]: root of unity of order 2^(t+1), derived from 3 (used here as the
    // primitive root of 998244353). Supports transform lengths up to 2^21.
    for (int t = 0, i = 1; t <= 20; i <<= 1, ++t) gn[t] = fp(3, (mod - 1) / (i << 1));
    scanf("%d", &n);
    for (int i = 0; i <= n; ++i) scanf("%d", a + i);
    solve(b, n + 1);
    for (int i = 0; i <= n; ++i) printf("%d ", b[i]);
    return 0;
}