多项式乘法
FFT
见这里
NTT
可以求出两个多项式相乘结果系数对任意NTT模数(可以表示为(a imes2^b+1)形式的质数)取模的结果。
其实只要把FFT里的单位副根变为该模数的原根就好了。
常见的NTT模数为998244353,原根为3。
多项式求逆
见这里
多项式板子
包括了NTT,求逆,ln,exp,快速幂。
看起来并不是非常高效
namespace poly{
#define vi vector<int>
#define ci const int&
#define Red(x) (x+=(x>>31)&mod)
const int LM=(1<<22),mod=998244353;
int lm,lg[LM+10],rev[LM+10],rt[LM+10][2],iv[LM+10],*p,*q;
int POW(int x,int y){
int ret=1;
while(y)y&1?ret=1ll*ret*x%mod:0,x=1ll*x*x%mod,y>>=1;
return ret;
}
void NTT(vi&f,ci op){
int tn=f.size(),l=lg[tn],r,t1,t2;
long long nr;
for(int i=0;i<tn;++i)rev[i]=(rev[i>>1]>>1)+(i&1)*(1<<l-1),rev[i]<i?swap(f[rev[i]],f[i]),0:0;
for(int i=2;i<=tn;i<<=1){
r=rt[i][op];
for(int j=0;j<tn;j+=i){
nr=1,p=&f[j],q=&f[j+(i>>1)];
for(int k=j;k<j+(i>>1);++k,nr=nr*r%mod,++p,++q)t1=*p,t2=nr*(*q)%mod,(*p)-=(((*p)=t1+t2)>=mod?mod:0),(*q)=t1-t2,Red((*q));
}
}
if(op)for(int i=0;i<tn;++i)f[i]=1ll*f[i]*iv[tn]%mod;
}
vi Poly(ci x){
vi ret;
return ret.push_back(x),ret;
}
vi Plus(vi x,vi y){
int sz=max(x.size(),y.size());
x.resize(sz),y.resize(sz);
for(int i=0;i<sz;++i)(x[i]+=y[i])>=mod?x[i]-=mod:0;
return x;
}
vi Minus(vi x,vi y){
int sz=max(x.size(),y.size());
x.resize(sz),y.resize(sz);
for(int i=0;i<sz;++i)x[i]-=y[i],Red(x[i]);
return x;
}
vi Mul(vi x,ci y){
for(int i=0;i<x.size();++i)x[i]=1ll*x[i]*y%mod;
return x;
}
vi Mul(vi x,vi y,ci sz){
int tl=x.size()+y.size()-1,lth=1;
while(lth<tl)lth<<=1;
x.resize(lth),y.resize(lth),NTT(x,0),NTT(y,0);
for(int i=0;i<lth;++i)x[i]=1ll*x[i]*y[i]%mod;
NTT(x,1),x.resize(sz);
return x;
}
vi Inv(vi x){
if(x.size()==1)return x[0]=POW(x[0],mod-2),x;
vi tmp=x;
int ts=x.size(),sz=(ts+1>>1);
tmp.resize(sz),tmp=Inv(tmp);
int tl=ts+tmp.size()+tmp.size()-2,lth=1;
while(lth<tl)lth<<=1;
x.resize(lth),tmp.resize(lth),NTT(x,0),NTT(tmp,0);
for(int i=0;i<lth;++i)tmp[i]=(2-1ll*x[i]*tmp[i])%mod*tmp[i]%mod,Red(tmp[i]);
NTT(tmp,1);
return tmp.resize(ts),tmp;
}
vi Ln(vi x){
vi tmp=x;
for(int i=1;i<tmp.size();++i)tmp[i-1]=1ll*i*tmp[i]%mod;
tmp.pop_back(),tmp=Mul(tmp,Inv(x),x.size());
for(int i=x.size()-1;i>0;--i)tmp[i]=1ll*tmp[i-1]*iv[i]%mod;
tmp[0]=0;
return tmp;
}
vi Exp(vi x){
if(x.size()==1)return x[0]=1,x;
int sz=(x.size()+1>>1);
vi tmp=x,t2;
tmp.resize(sz),tmp=Exp(tmp),t2=tmp;
t2.resize(x.size());
return Mul(tmp,Plus(Minus(Poly(1),Ln(t2)),x),x.size());
}
vi POW(vi x,ci y,ci yc){
int pw=0,in,ls,ns=x.size();
while(pw<ns&&!x[pw])++pw;
if(pw==ns)return x;
if(1ll*pw*y>=ns){
vi ret;
return ret.resize(ns),ret;
}
for(int i=pw;i<x.size();++i)x[i-pw]=x[i];
x.resize(ns-pw);
in=POW(ls=x[0],mod-2),x=Mul(x,in);
vi tmp=Exp(Mul(Ln(x),y));
ls=POW(ls,yc),tmp=Mul(tmp,ls);
vi ret;ret.resize(ns);
for(int i=0;i+pw*y<ns;++i)ret[i+pw*y]=tmp[i];
return ret;
}
/*vi val[LM<<1],vl;
int cnt;
void Solve(ci l,ci r){
++cnt,val[cnt].resize(0);
if(l>r)return(void)(val[cnt].push_back(1));
if(l==r)return(void)(val[cnt].push_back(mod-l),val[cnt].push_back(1));
int id=cnt,mid=l+r>>1,lc=cnt+1,rc;
Solve(l,mid),rc=cnt+1,Solve(mid+1,r);
val[id]=Mul(val[lc],val[rc],val[lc].size()+val[rc].size()-1);
}
vi Calc(vi x){
vl=x;
cnt=0,Solve(0,x.size()-1);
}*/
void init(ci x){
lm=1;
while(lm<x)lm<<=1;
for(int i=2;i<=lm;++i)lg[i]=lg[i>>1]+1,lg[i]!=lg[i-1]?rt[i][0]=POW(3,(mod-1)/i),rt[i][1]=POW(rt[i][0],mod-2):0;
for(int i=1;i<=lm;++i)iv[i]=POW(i,mod-2);
}
}