多项式
快速傅里叶变换
使用 FFT 的场合比较少,一般都是要结合 MTT 之类的.
对于复数 $(x,y)$,有 3 种运算:
$(x,y)+(x',y')=(x+x',y+y')$
$(x,y)-(x',y')=(x-x',y-y')$
$(x,y)*(x',y')=(x*x'-y*y',x*y'+y*x')$
#include <cstdio>
#include <vector>
#include <cmath>
#include <cstring>
#include <algorithm>
#define ll long long
#define db long double
#define pb push_back
#define N 1000007
#define setIO(s) freopen(s".in","r",stdin)
using namespace std;
const db pi=acos(-1.0);
struct cp {
db x,y;
cp(db a=0,db b=0) { x=a,y=b; }
cp operator+(const cp &b) const { return cp(x+b.x,y+b.y); }
cp operator-(const cp &b) const { return cp(x-b.x,y-b.y); }
cp operator*(const cp &b) const { return cp(x*b.x-y*b.y,x*b.y+y*b.x); }
}A[N<<2],B[N<<2];
void FFT(cp *a,int len,int op) {
for(int i=0,k=0;i<len;++i) {
if(i>k) swap(a[i],a[k]);
for(int j=len>>1;(k^=j)<j;j>>=1);
}
for(int l=1;l<len;l<<=1) {
cp wn(cos(pi/l),op*sin(pi/l)),x,y;
for(int i=0;i<len;i+=l<<1) {
cp w(1,0);
for(int j=0;j<l;++j) {
x=a[i+j],y=w*a[i+j+l];
a[i+j]=x+y;
a[i+j+l]=x-y;
w=w*wn;
}
}
}
if(op==-1) {
for(int i=0;i<len;++i) a[i].x/=len;
}
}
int main() {
// setIO("input");
int n,m,lim,x;
scanf("%d%d",&n,&m);
for(lim=1;lim<(n+m+1);lim<<=1);
for(int i=0;i<=n;++i) {
scanf("%d",&x),A[i].x=(db)x;
}
for(int i=0;i<=m;++i) {
scanf("%d",&x),B[i].x=(db)x;
}
FFT(A,lim,1),FFT(B,lim,1);
for(int i=0;i<lim;++i) A[i]=A[i]*B[i];
FFT(A,lim,-1);
for(int i=0;i<=n+m;++i) {
printf("%d ",(int)(A[i].x+0.5));
}
return 0;
}
任意模数NTT (MTT)
当模数不能写成 $a imes 2^k+1$ 的时候就需要用到拆系数 FFT (MTT)了.
令 $f(x)=wf_{0}(x)+f_{1}(x)$,$g(x)$ 同理.
然后 $f*g=(wf_{0}+f_{1})(wg_{0}+g_{1})=(f_{0}g_{0})w^2+(f_{0}g_{1}+f_{1}g_{0})w+f_{1}g_{1}$.
做 7 次 FFT 即可,这个 w 选 $2^{15}$ 就好了.
#include <cstdio>
#include <vector>
#include <cmath>
#include <cstring>
#include <algorithm>
#define ll long long
#define db long double
#define pb push_back
#define N 100007
#define setIO(s) freopen(s".in","r",stdin)
using namespace std;
const db pi=acos(-1.0);
struct cp {
db x,y;
cp(db a=0,db b=0) { x=a,y=b; }
cp operator+(const cp &b) const { return cp(x+b.x,y+b.y); }
cp operator-(const cp &b) const { return cp(x-b.x,y-b.y); }
cp operator*(const cp &b) const { return cp(x*b.x-y*b.y,x*b.y+y*b.x); }
}f[2][N<<2],g[2][N<<2],ans[3][N<<2];
int A[N],B[N];
int lim;
ll C[N];
void FFT(cp *a,int len,int op) {
for(int i=0,k=0;i<len;++i) {
if(i>k) swap(a[i],a[k]);
for(int j=len>>1;(k^=j)<j;j>>=1);
}
for(int l=1;l<len;l<<=1) {
cp wn(cos(pi/l),op*sin(pi/l)),x,y;
for(int i=0;i<len;i+=l<<1) {
cp w(1,0);
for(int j=0;j<l;++j) {
x=a[i+j],y=w*a[i+j+l];
a[i+j]=x+y;
a[i+j+l]=x-y;
w=w*wn;
}
}
}
}
ll nor(db x,ll mod) {
return (ll)((ll)(x/lim+0.5)%mod+mod)%mod;
}
void MTT(int *a,int n,int *b,int m,ll mod,ll *c) {
for(lim=1;lim<=(n+m);lim<<=1);
for(int i=0;i<=n;++i) {
f[0][i].x=a[i]>>15;
f[1][i].x=a[i]&0x7fff;
}
for(int i=0;i<=m;++i) {
g[0][i].x=b[i]>>15;
g[1][i].x=b[i]&0x7fff;
}
FFT(f[0],lim,1),FFT(f[1],lim,1);
FFT(g[0],lim,1),FFT(g[1],lim,1);
for(int i=0;i<lim;++i) {
ans[0][i]=f[0][i]*g[0][i];
ans[1][i]=f[0][i]*g[1][i]+f[1][i]*g[0][i];
ans[2][i]=f[1][i]*g[1][i];
}
FFT(ans[0],lim,-1);
FFT(ans[1],lim,-1);
FFT(ans[2],lim,-1);
for(int i=0;i<=n+m;++i) {
ll x=(nor(ans[0][i].x,mod)<<30ll)%mod;
ll y=(nor(ans[1][i].x,mod)<<15ll)%mod;
ll z=nor(ans[2][i].x,mod)%mod;
c[i]=((x+y)%mod+z)%mod;
}
}
int main() {
//setIO("input");
int n,m;
ll mod;
scanf("%d%d%lld",&n,&m,&mod);
for(int i=0;i<=n;++i) scanf("%d",&A[i]);
for(int i=0;i<=m;++i) scanf("%d",&B[i]);
MTT(A,n,B,m,mod,C);
for(int i=0;i<=n+m;++i) printf("%lld ",C[i]);
return 0;
}
多项式求逆
公式 $B=2B'-AB'^2$
这里注意复制 $A$ 数组的时候不要复制多了,否则会让前面的 B 多算.
#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
#define N 100008
#define ll long long
#define pb push_back
#define mod 998244353
#define setIO(s) freopen(s".in","r",stdin)
using namespace std;
int A[N<<2],B[N<<2],f[N<<1],g[N<<1];
int qpow(int x,int y) {
int tmp=1;
for(;y;y>>=1,x=(ll)x*x%mod) {
if(y&1) tmp=(ll)tmp*x%mod;
}
return tmp;
}
int get_inv(int x) {
return qpow(x,mod-2);
}
void NTT(int *a,int len,int op) {
for(int i=0,k=0;i<len;++i) {
if(i>k) swap(a[i],a[k]);
for(int j=len>>1;(k^=j)<j;j>>=1);
}
for(int l=1;l<len;l<<=1) {
int wn=qpow(3,(mod-1)/(l<<1));
if(op==-1) {
wn=get_inv(wn);
}
for(int i=0;i<len;i+=l<<1) {
int w=1,x,y;
for(int j=0;j<l;++j) {
x=a[i+j],y=(ll)w*a[i+j+l]%mod;
a[i+j]=(ll)(x+y)%mod;
a[i+j+l]=(ll)(x-y+mod)%mod;
w=(ll)w*wn%mod;
}
}
}
if(op==-1) {
int iv=get_inv(len);
for(int i=0;i<len;++i) {
a[i]=(ll)a[i]*iv%mod;
}
}
}
void get_inv(int *a,int *b,int len,int la) {
if(len==1) {
b[0]=get_inv(a[0]);
return;
}
get_inv(a,b,len>>1,la);
int l=len<<1;
for(int i=0;i<min(len,la);++i) A[i]=a[i];
for(int i=0;i<len>>1;++i) B[i]=b[i];
for(int i=min(len,la);i<l;++i) A[i]=0;
for(int i=len>>1;i<l;++i) B[i]=0;
NTT(A,l,1),NTT(B,l,1);
for(int i=0;i<l;++i) {
A[i]=(ll)A[i]*B[i]%mod*B[i]%mod;
}
NTT(A,l,-1);
for(int i=0;i<len;++i) {
b[i]=(ll)((ll)(b[i]<<1)%mod-A[i]+mod)%mod;
}
}
int main() {
// setIO("input");
int n,lim;
scanf("%d",&n);
for(int i=0;i<n;++i) {
scanf("%d",&g[i]);
}
for(lim=1;lim<n;lim<<=1);
get_inv(g,f,lim,n);
for(int i=0;i<n;++i) {
printf("%d ",f[i]);
}
return 0;
}
分治NTT
#include <cstdio>
#include <cstring>
#include <algorithm>
#define N 100008
#define ll long long
#define mod 998244353
#define setIO(s) freopen(s".in","r",stdin)
using namespace std;
int A[N<<2],B[N<<2],f[N],g[N];
int qpow(int x,int y) {
int tmp=1;
for(;y;y>>=1,x=(ll)x*x%mod) {
if(y&1) tmp=(ll)tmp*x%mod;
}
return tmp;
}
int get_inv(int x) {
return qpow(x,mod-2);
}
void NTT(int *a,int len,int op) {
for(int i=0,k=0;i<len;++i) {
if(i>k) swap(a[i],a[k]);
for(int j=len>>1;(k^=j)<j;j>>=1);
}
for(int l=1;l<len;l<<=1) {
int wn=qpow(3,(mod-1)/(l<<1)),x,y,w;
if(op==-1) {
wn=get_inv(wn);
}
for(int i=0;i<len;i+=l<<1) {
w=1;
for(int j=0;j<l;++j) {
x=a[i+j],y=(ll)a[i+j+l]*w%mod;
a[i+j]=(ll)(x+y)%mod;
a[i+j+l]=(ll)(x-y+mod)%mod;
w=(ll)w*wn%mod;
}
}
}
if(op==-1) {
int iv=get_inv(len);
for(int i=0;i<len;++i) {
a[i]=(ll)a[i]*iv%mod;
}
}
}
void solve(int l,int r) {
if(l==r) {
return;
}
int mid=(l+r)>>1,lim,s1=0,s2=0;
solve(l,mid);
for(int i=l;i<=mid;++i) A[s1++]=f[i];
for(int i=0;i<=r-l;++i) B[s2++]=g[i];
for(lim=1;lim<(s1+s2);lim<<=1);
for(int i=s1;i<lim;++i) A[i]=0;
for(int i=s2;i<lim;++i) B[i]=0;
NTT(A,lim,1),NTT(B,lim,1);
for(int i=0;i<lim;++i) A[i]=(ll)A[i]*B[i]%mod;
NTT(A,lim,-1);
for(int i=mid+1;i<=r;++i) {
(f[i]+=A[i-l])%=mod;
}
solve(mid+1,r);
}
int main() {
// setIO("input");
int n;
scanf("%d",&n);
for(int i=1;i<n;++i) {
scanf("%d",&g[i]);
}
f[0]=1;
solve(0,n-1);
for(int i=0;i<n;++i) {
printf("%d ",f[i]);
}
return 0;
}