Part1:FFT(fast fast tle)
前置知识:复数,单位根,多项式的系数表达法,多项式的点值表达法
- 复数:
可以表示为((a+bi)),可以看做原点到((a,b))一个向量,其中(i=sqrt{-1})。
复数可以进行加,减,乘(向量的除法有点问题),其中
即:
同时复数的乘还有和向量一样的几何意义:模长相乘,幅角相加
- 单位根:
在OI中,经常用到2的正整数次幂相关的数,因为这样方便处理,为方便,我们规定下文的(n)为2的正整数次幂。
定义:如果(w_n^n=1)那么(w_n)为(n)次单位根
因为(w_n^n=1),根据复数乘的几何意义,可知模长为1,幅角为(frac{2pi}{n}),易得单位根
然后(w^k_n)的幅角为(frac{2kpi}{n}),所以
单位根这里还需要两个性质:
性质一:
即
性质二:
即
- 多项式的系数表达法:
就是平时的表达方法,用(n+1)个系数表示一个(n)次多项式,比如:
该方法易读,也易求值,但很难快速求卷积。
- 多项式的点值表达法:
就是用(n+1)个点来表示一个(n)次多项式,比如:
该方法不易理解,但很容易求卷积。
如果两个多项式(f,g)满足(fx_0==gx_0,fx_1==gx_1,...,fx_n==gx_n),则新多项式
正题:快速傅里叶变换(FFT)以及快速傅里叶逆变换(IFFT)
从上面两种多项式的表达方式中,我们可以发现如果能快速的把多项式在系数与点值中转换,就可以快速的获取两个多项式的卷积。
- 1、系数多项式转点值多项式(快速傅里叶变换)
给出多项式
我们需要快速求出(f(1),f(w_n),...,f(w^{n-1}_{n-1}))
先将(f)按奇偶分类分为
我们设
那么有
带入(x=w^k_n),
带入(x=w^{k+frac{n}{2}}_n)
根据性质二,有
可以发现(f(w^k_n))和(f(w^{k+frac{n}{2}}_n))两项只差第二项的系数,所以我们可以只用一半的时间就处理整个多项式。
可以发现,如果我们递归的处理,就可以用(O(nlog n))的复杂度实现FFT。
- 2、点值多项式转系数多项式(快速傅里叶逆变换)
我们可以把式子列成一个矩阵(没学过矩阵可以先学再看或跳过推导):
其中((a_0,a_1,a_2,...,a_{n-1}))为系数表达法,((y_0,y_1,y_2,...,y_{n-1}))为省略(x)的点值表达法。
如果我们能快速求出左边这个矩阵的逆矩阵,我们就能快速转换。
考虑矩阵求逆((O(n^3))完全负优化)
但我们可以发现原矩阵中所有数之间是有关联的,我们可以考虑转换。
设(V)为原矩阵,(G)为逆矩阵,考虑最终矩阵(E)在((i,j))上的值:
因为(V)和(G)互逆,所以(E)是单位矩阵,只有当(i=j)时才会有值1,否则为0。
我们先证明一个引理:当(k)不是(n)的倍数时
由等比数列求和得
因为(k)不是(n)的倍数,所以(w_n^k ot=1),即分母不为0,所以该引理成立。
根据这个引理,可以发现矩阵(G)有一个比较简单的构造方式,即(G(i,j)=w_n^{-ij})
这时
当(i-j)不为(n)的倍数(不为0时),(E(i,j)=1),但当(i=j)时,已知(E(i,j)=n),跟单位矩阵有点偏差,我们在前面加一个(frac{1}{n})。
好吧,这个推导其实有些牵强,只用把他当做结论记就可以了。
这样,我们就有:
这样我们就可以用类似系数转点值的方法转换了,只是这边的单位根要取反,其实在使用起来时就是
非常简单,只用在FFT的基础上略作修改即可。
最后因为直接乘出来的答案是真实值的(n)倍,所以要除以(n)。
然后我们就可以写出最基本的递归版FFT(虽然C++有自带complex类型,但用起来会比较慢,建议手写一个):
#include<bits/stdc++.h>
using namespace std;
const int N=1000010;
const double pi=acos(-1);
int n,m,lg[N<<1];
struct Complex{
double x,y;
Complex(double x=0,double y=0):x(x),y(y){}
};
Complex operator+(Complex a,Complex b){return Complex(a.x+b.x,a.y+b.y);}
Complex operator-(Complex a,Complex b){return Complex(a.x-b.x,a.y-b.y);}
Complex operator*(Complex a,Complex b){return Complex(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);}
Complex c[N<<2],b[N<<2],t[N<<2],s[50];
void fft(Complex *f,int l,int len,int op){
if(!len)return;
for(int i=l;i<l+(len<<1);++i)t[i]=f[i];
for(int i=l;i<l+len;++i)f[i]=t[l+((i-l)<<1)],f[i+len]=t[l+((i-l)<<1|1)];
fft(f,l,len>>1,op);
fft(f,l+len,len>>1,op);
Complex tmp=s[lg[len]],buf=Complex(1,0),d;
tmp.y*=op;
for(int i=l;i<l+len;++i){
d=buf*f[i+len];
t[i]=f[i]+d;
t[i+len]=f[i]-d;
buf=buf*tmp;
}
for(int i=l;i<l+(len<<1);++i)f[i]=t[i];
}
int main(){
int n,m;
cin>>n>>m;
for(int i=2;i<=n+m;++i)lg[i]=lg[i>>1]+1;
for(int i=0;i<=n;++i)scanf("%lf",&b[i].x);
for(int i=0;i<=m;++i)scanf("%lf",&c[i].x);
for(m+=n,n=1;n<=m;n<<=1);
for(int i=1,j=0;i<=n;i<<=1,++j)s[j]=Complex(cos(pi/i),sin(pi/i));
fft(b,0,n>>1,1);
fft(c,0,n>>1,1);
for(int i=0;i<n;++i)b[i]=b[i]*c[i];
fft(b,0,n>>1,-1);
for(int i=0;i<=m;++i)printf("%.0lf ",fabs(b[i].x)/n);
}
但递归版的着实很慢,这时就要用到神奇的:
二进制反转
因为开始的时候我们把所有数按奇偶性分类,所以我们不能直接枚举区间长度然后处理。但正是因为我们按奇偶分类,我们可以直接按二进制分类,然后一个数的真实位置就是该数的下标按二进制反转后的值。
求真实位置的方法其实很简单:
for(int i=0;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
一个数的二进制反转为这个数除2的反转除2(得到后(l-1)位),如果这个数第一位为1,那么把最后一位加1。
然后就可以用非递归的写法优化递归的大常数:
#include<bits/stdc++.h>
#define eps 1e-6
using namespace std;
const int N=4e6+10;
const double pi=acos(-1);
struct Complex{
double x,y;
Complex(double x=0,double y=0):x(x),y(y){}
};
Complex operator+(Complex a,Complex b){return Complex(a.x+b.x,a.y+b.y);}
Complex operator-(Complex a,Complex b){return Complex(a.x-b.x,a.y-b.y);}
Complex operator*(Complex a,Complex b){return Complex(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);}
int limit,l;
int r[N];
void FFT(Complex*f,int op){
for(int i=0;i<limit;++i){
if(i<r[i])swap(f[i],f[r[i]]);
}
for(int mid=2;mid<=limit;mid*=2){
Complex wn=Complex(cos(2*pi/mid),op*sin(2*pi/mid));
for(int j=0;j<limit;j+=mid){
Complex w=Complex(1,0);
for(int k=j;k<j+mid/2;++k,w=w*wn){
Complex x=f[k],y=w*f[k+mid/2];
f[k]=x+y;
f[k+mid/2]=x-y;
}
}
}
if(op==-1){
for(int i=0;i<limit;++i){
f[i].x/=limit;
}
}
}
int n,m;
Complex a[N],b[N];
int main(){
cin>>n>>m;
for(int i=0;i<=n;++i)scanf("%lf",&a[i].x);
for(int i=0;i<=m;++i)scanf("%lf",&b[i].x);
for(limit=1,l=0;limit<=n+m;limit*=2,l++);
for(int i=0;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
FFT(a,1),FFT(b,1);
for(int i=0;i<limit;++i)a[i]=a[i]*b[i];
FFT(a,-1);
for(int i=0;i<=n+m;++i){
if(fabs(a[i].x)<eps)printf("0 ");
else printf("%.0lf ",a[i].x);
}
}
例题:[ZJOI2014]力
Part2:NTT(快速数论变换)
在FFT的时候,因为会用到大量sin和cos以及double的乘法,会让精度有巨大的损失,所以我们要用一些其他的方法来让精度不损失。这时就要用到快速数论变换。
前置知识:原根
- 原根
原根:是一个数学符号。设(m)是正整数,(a)是整数,若(a)模(m)的阶等于(varphi(m)),则称(a)为模(m)的一个原根。
阶:使(a^nequiv 1(mod p))成立的最小正整数(n)叫做(a)模(p)的阶。这里(equiv)指同余符号,代表(a^n)除以(p)的余数跟1除以(p)的余数相等。
一般情况下模数为998244353,而998244353的原根为3。因为有:
而且对于(1leq i< 998244352),没有(3^iequiv 1(mod 998244353))
下文中我们默认(p=998244353,G=3)((p)为模数,(G)为原根)
然后我们要尝试用原根相关的东西替换掉单位根:
我们需要知道(w_n)怎么求:
因为(w_n^nequiv 1(mod p))
所以(w_n^nequiv G^{p-1}(mod p)),
所以有(w_nequiv G^{frac{p-1}{n}}(mod p))
先证明几个性质,
性质一:(w^{k+frac{n}{2}}=-w^k)
性质二:(w^{2k}_{2n}=w^k_n)
这两个性质的证明方法和FFT一致
所以就跟FFT完全一样了
#include<bits/stdc++.h>
using namespace std;
const int N=3e6+10,p=998244353,G=3,Gi=332748118;
int add(int x,int y){return(x+y)%p;}
int mul(int x,int y){return 1ll*x*y%p;}
int mpow(int a,int n){
int ret=1;
while(n){
if(n&1)ret=mul(ret,a);
a=mul(a,a);
n/=2;
}
return ret;
}
int n,m;
int a[N],b[N];
int r[N],limit;
void ntt(int*f,int op){
for(int i=0;i<limit;++i)if(i<r[i])swap(f[i],f[r[i]]);
for(int len=2;len<=limit;len*=2){
int wn=mpow(op==1?G:Gi,(p-1)/len);
for(int j=0;j<limit;j+=len){
int w=1;
for(int k=j;k<j+len/2;++k,w=mul(w,wn)){
int x=f[k],y=mul(w,f[k+len/2]);
f[k]=add(x,y);
f[k+len/2]=add(x,p-y);
}
}
}
if(op==-1){
int inv=mpow(limit,p-2);
for(int i=0;i<limit;++i){
f[i]=mul(f[i],inv);
}
}
}
int main(){
cin>>n>>m;
for(int i=0;i<=n;++i)scanf("%d",a+i);
for(int i=0;i<=m;++i)scanf("%d",b+i);
int l=0;limit=1;
while(limit<=n+m)limit*=2,l++;
for(int i=1;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
ntt(a,1);
ntt(b,1);
for(int i=0;i<limit;++i)a[i]=mul(a[i],b[i]);
ntt(a,-1);
for(int i=0;i<=n+m;++i)printf("%d ",a[i]);
}
Part3:多项式求逆
我们要求
现在已经知道了:
然后可以转化:
根据这个理论基础,我们可以做出多项式求逆:
#include<bits/stdc++.h>
#define res register int
using namespace std;
const int N=3e6+10,p=998244353,G=3,Gi=332748118;
inline int add(res x,res y){return(x+y)>=p?x+y-p:x+y;}
inline int mul(res x,res y){return 1ll*x*y-1ll*x*y/p*p;}
inline int mpow(res a,res n){
res ret=1;
while(n){
if(n&1)ret=mul(ret,a);
a=mul(a,a);
n/=2;
}
return ret;
}
int g[2][N];
void init(){
for(int i=1;i<N;i*=2){
g[0][i]=mpow(G,(p-1)/i);
g[1][i]=mpow(Gi,(p-1)/i);
}
}
int n,m;
int ls[5][N];
int a[N],b[N],c[N];
int r[N];
void ntt(int*f,int limit,int op){
for(res i=0;i<limit;++i)if(i<r[i])swap(f[i],f[r[i]]);
for(res len=2;len<=limit;len*=2){
res wn=op==1?g[0][len]:g[1][len];
for(res j=0;j<limit;j+=len){
res w=1;
for(res k=j;k<j+len/2;++k,w=mul(w,wn)){
res x=f[k],y=mul(w,f[k+len/2]);
f[k]=add(x,y);
f[k+len/2]=add(x,p-y);
}
}
}
if(op==-1){
res inv=mpow(limit,p-2);
for(res i=0;i<limit;++i){
f[i]=mul(f[i],inv);
}
}
}
void mul(int*a,int*b,int*c,int n,int m){
int limit=1;
while(limit<n+m-1)limit*=2;
for(res i=1;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)*limit/2);
for(int i=0;i<limit;++i)ls[0][i]=a[i],ls[1][i]=b[i];
ntt(ls[0],limit,1);
ntt(ls[1],limit,1);
for(res i=0;i<limit;++i)c[i]=mul(ls[0][i],ls[1][i]);
ntt(c,limit,-1);
}
void inv(int*a,int*b,int n){
b[0]=mpow(a[0],p-2);
for(int len=1,l=0,limit;len<2*n;len*=2){
limit=len*2,l++;
for(int i=1;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)*limit/2);
for(int i=0;i<len;++i)ls[0][i]=a[i],ls[1][i]=b[i];
for(int i=len;i<limit;++i)ls[0][i]=ls[1][i]=0;
ntt(ls[0],limit,1),ntt(ls[1],limit,1);
for(int i=0;i<limit;++i){
b[i]=add(mul(2,ls[1][i]),p-(mul(ls[0][i],mul(ls[1][i],ls[1][i]))));
}
ntt(b,limit,-1);
for(int i=len;i<limit;++i)b[i]=0;
}
}
inline int read(){
res ret=0;char c;
for(c=getchar();!isdigit(c);c=getchar());
for(;isdigit(c);ret=ret*10+c-'0',c=getchar());
return ret;
}
int main(){
init();
cin>>n;
for(res i=0;i<n;++i)a[i]=read();
inv(a,c,n);
for(res i=0;i<n;++i)printf("%d ",c[i]);puts("");
}
Part 4:多项式ln
我们要求
可推导:
#include<bits/stdc++.h>
#define res register int
using namespace std;
const int N=3e6+10,p=998244353,G=3,Gi=332748118;
inline int add(res x,res y){return(x+y)>=p?x+y-p:x+y;}
inline int mul(res x,res y){return 1ll*x*y-1ll*x*y/p*p;}
inline int mpow(res a,res n){
res ret=1;
while(n){
if(n&1)ret=mul(ret,a);
a=mul(a,a);
n/=2;
}
return ret;
}
int g[2][N];
void init(){
for(int i=1;i<N;i*=2){
g[0][i]=mpow(G,(p-1)/i);
g[1][i]=mpow(Gi,(p-1)/i);
}
}
int n,m;
int ls[5][N],used;
//0,1 mul
//2,3,4 ln
int a[N],b[N],c[N];
int r[N];
void ntt(int*f,int limit,int op){
for(res i=0;i<limit;++i)if(i<r[i])swap(f[i],f[r[i]]);
for(res len=2;len<=limit;len*=2){
res wn=op==1?g[0][len]:g[1][len];
for(res j=0;j<limit;j+=len){
res w=1;
for(res k=j;k<j+len/2;++k,w=mul(w,wn)){
res x=f[k],y=mul(w,f[k+len/2]);
f[k]=add(x,y);
f[k+len/2]=add(x,p-y);
}
}
}
if(op==-1){
res inv=mpow(limit,p-2);
for(res i=0;i<limit;++i){
f[i]=mul(f[i],inv);
}
}
}
void mul(int*a,int*b,int*c,int n,int m){
int limit=1;
while(limit<n+m-1)limit*=2;
for(res i=1;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)*limit/2);
for(int i=0;i<limit;++i)ls[0][i]=a[i],ls[1][i]=b[i];
ntt(ls[0],limit,1);
ntt(ls[1],limit,1);
for(res i=0;i<limit;++i)c[i]=mul(ls[0][i],ls[1][i]);
ntt(c,limit,-1);
}
void inv(int*a,int*b,int n){
b[0]=mpow(a[0],p-2);
for(int len=1,l=0,limit;len<2*n;len*=2){
limit=len*2,l++;
for(int i=1;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)*limit/2);
for(int i=0;i<len;++i)ls[0][i]=a[i],ls[1][i]=b[i];
for(int i=len;i<limit;++i)ls[0][i]=ls[1][i]=0;
ntt(ls[0],limit,1),ntt(ls[1],limit,1);
for(int i=0;i<limit;++i){
b[i]=add(mul(2,ls[1][i]),p-(mul(ls[0][i],mul(ls[1][i],ls[1][i]))));
}
ntt(b,limit,-1);
for(int i=len;i<limit;++i)b[i]=0;
}
}
void direv(int*a,int*b,int n){
for(int i=1;i<n;++i){
b[i-1]=mul(a[i],i);
}
}
void inter(int*a,int*b,int n){
b[0]=0;
for(int i=1;i<n;++i){
b[i]=mul(a[i-1],mpow(i,p-2));
}
}
void ln(int*a,int*b,int n){
direv(a,ls[2],n);
inv(a,ls[3],n);
mul(ls[2],ls[3],ls[4],n,n);
inter(ls[4],b,2*n);
for(int i=n;i<2*n;++i)b[i]=0;
}
void sqrt(int*a,int*b,int n){
b[0]=1;
}
inline int read(){
res ret=0;char c;
for(c=getchar();!isdigit(c);c=getchar());
for(;isdigit(c);ret=ret*10+c-'0',c=getchar());
return ret;
}
int main(){
init();
cin>>n;
for(res i=0;i<n;++i)a[i]=read();
ln(a,c,n);
for(res i=0;i<n;++i)printf("%d ",c[i]);puts("");
}
Part ex:黑科技
突然学会了一个黑科技:
牛顿迭代,泰勒展开
设(F,F_0)是多项式,(G)是某个多项式函数。
现在要求
现在已经知道了
我们对(G(F))在(F_0)处泰勒展开
因为(F)和(F_0)的前(frac n2)相同,所以((F-F_0))的前(frac n2)为0,所以对于(n>1)的情况((F-F_0)^n)的前n为必定为0,对答案无意义,可舍去。
所以有
因为(G(F)equiv 0(mod x^n)),所以有
这里要注意当求(G'(F))时,我们要把(F)当成一个未知数,这样(G'(F)=G'F)
Part 4:多项式exp
用黑科技可求解。
给出多项式(A)
#include<bits/stdc++.h>
#define res register int
using namespace std;
const int N=3e6+10,p=998244353,G=3,Gi=332748118;
inline int add(res x,res y){return(x+y)>=p?x+y-p:x+y;}
inline int mul(res x,res y){return 1ll*x*y-1ll*x*y/p*p;}
inline int mpow(res a,res n){
res ret=1;
while(n){
if(n&1)ret=mul(ret,a);
a=mul(a,a);
n/=2;
}
return ret;
}
int g[2][N];
void init(){
for(int i=1;i<N;i*=2){
g[0][i]=mpow(G,(p-1)/i);
g[1][i]=mpow(Gi,(p-1)/i);
}
}
int n,m;
int ls[10][N],used;
//0,1 mul
//2,3,4 ln
//5,6 exp
int a[N],b[N],c[N];
int r[N];
void ntt(int*f,int limit,int op){
for(res i=0;i<limit;++i)if(i<r[i])swap(f[i],f[r[i]]);
for(res len=2;len<=limit;len*=2){
res wn=op==1?g[0][len]:g[1][len];
for(res j=0;j<limit;j+=len){
res w=1;
for(res k=j;k<j+len/2;++k,w=mul(w,wn)){
res x=f[k],y=mul(w,f[k+len/2]);
f[k]=add(x,y);
f[k+len/2]=add(x,p-y);
}
}
}
if(op==-1){
res inv=mpow(limit,p-2);
for(res i=0;i<limit;++i){
f[i]=mul(f[i],inv);
}
}
}
void mul(int*a,int*b,int*c,int n,int m){
int limit=1;
while(limit<n+m-1)limit*=2;
for(int i=0;i<limit;++i)ls[0][i]=ls[1][i]=0;
for(res i=1;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)*limit/2);
for(int i=0;i<limit;++i)ls[0][i]=a[i],ls[1][i]=b[i];
ntt(ls[0],limit,1);
ntt(ls[1],limit,1);
for(res i=0;i<limit;++i)c[i]=mul(ls[0][i],ls[1][i]);
ntt(c,limit,-1);
}
void inv(int*a,int*b,int n){
b[0]=mpow(a[0],p-2);
for(int len=1,limit;len<2*n;len*=2){
limit=len*2;
for(int i=1;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)*limit/2);
for(int i=0;i<len;++i)ls[0][i]=a[i],ls[1][i]=b[i];
for(int i=len;i<limit;++i)ls[0][i]=ls[1][i]=0;
ntt(ls[0],limit,1),ntt(ls[1],limit,1);
for(int i=0;i<limit;++i){
b[i]=add(mul(2,ls[1][i]),p-(mul(ls[0][i],mul(ls[1][i],ls[1][i]))));
}
ntt(b,limit,-1);
for(int i=len;i<limit;++i)b[i]=0;
}
}
void direv(int*a,int*b,int n){
for(int i=1;i<n;++i){
b[i-1]=mul(a[i],i);
}
}
void inter(int*a,int*b,int n){
b[0]=0;
for(int i=1;i<n;++i){
b[i]=mul(a[i-1],mpow(i,p-2));
}
}
void ln(int*a,int*b,int n){
direv(a,ls[2],n);
inv(a,ls[3],n);
mul(ls[2],ls[3],ls[4],n,n);
inter(ls[4],b,2*n);
for(int i=n;i<2*n;++i)b[i]=0;
}
void exp(int*a,int*b,int n){
b[0]=1;
for(int len=1;len<2*n;len*=2){
int limit=len*2;
ln(b,ls[5],len);
for(int i=0;i<len;++i){
ls[5][i]=add(p-ls[5][i],a[i]);
}
ls[5][0]=add(ls[5][0],1);
for(int i=0;i<len;++i)ls[6][i]=b[i];
mul(ls[5],ls[6],b,len,len);
for(int i=len;i<limit;++i)b[i]=0;
}
}
inline int read(){
res ret=0;char c;
for(c=getchar();!isdigit(c);c=getchar());
for(;isdigit(c);ret=ret*10+c-'0',c=getchar());
return ret;
}
int main(){
init();
cin>>n;
for(res i=0;i<n;++i)a[i]=read();
exp(a,c,n);
for(res i=0;i<n;++i)printf("%d ",c[i]);puts("");
}
Part 5:多项式开根
也使用上文黑科技:
给出多项式(A)
#include<bits/stdc++.h>
#define res register int
#define ll long long
ll js;
using namespace std;
const int N=3e6+10,p=998244353,G=3,Gi=332748118;
inline int add(res x,res y){return(x+y)%p;}
inline int mul(res x,res y){return 1ll*x*y%p;}
inline int mpow(res a,res n){
res ret=1;
while(n){
if(n&1)ret=mul(ret,a);
a=mul(a,a);
n/=2;
}
return ret;
}
int g[2][N];
int inv2;
void init(){
inv2=mpow(2,p-2);
for(int i=1;i<N;i*=2){
g[0][i]=mpow(G,(p-1)/i);
g[1][i]=mpow(Gi,(p-1)/i);
}
}
int n,m;
int ls[7][N],used;
//0,1 mul inv
//2,3,4 ln sqrt
//5,6 exp
int a[N],b[N],c[N];
int r[N];
void ntt(int*f,res limit,res op){
for(res i=0;i<limit;++i)if(i<r[i])swap(f[i],f[r[i]]);
for(res len=2;len<=limit;len*=2){
res wn=op==1?g[0][len]:g[1][len];
for(res j=0;j<limit;j+=len){
res w=1;
for(res k=j;k<j+len/2;++k,w=mul(w,wn)){
res x=f[k],y=mul(w,f[k+len/2]);
f[k]=add(x,y);
f[k+len/2]=add(x,p-y);
}
}
}
if(op==-1){
res inv=mpow(limit,p-2);
for(res i=0;i<limit;++i){
f[i]=mul(f[i],inv);
}
}
}
void mul(int*a,int*b,int*c,int n,int m){
int limit=1;
while(limit<n+m-1)limit*=2;
for(res i=0;i<limit;++i)ls[0][i]=ls[1][i]=0;
for(res i=1;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)*limit/2);
for(res i=0;i<limit;++i)ls[0][i]=a[i],ls[1][i]=b[i];
ntt(ls[0],limit,1);
ntt(ls[1],limit,1);
for(res i=0;i<limit;++i)c[i]=mul(ls[0][i],ls[1][i]);
ntt(c,limit,-1);
}
void inv(int*a,int*b,int n){
b[0]=mpow(a[0],p-2);
for(res len=2,limit;len<2*n;len*=2){
limit=len*2;
for(res i=1;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)*limit/2);
for(res i=0;i<len;++i)ls[0][i]=a[i],ls[1][i]=b[i];
for(res i=len;i<limit;++i)ls[0][i]=ls[1][i]=0;
ntt(ls[0],limit,1),ntt(ls[1],limit,1);
for(res i=0;i<limit;++i){
b[i]=mul(add(2,p-(mul(ls[0][i],ls[1][i]))),ls[1][i]);
}
ntt(b,limit,-1);
for(res i=len;i<limit;++i)b[i]=0;
}
}
inline void direv(int*a,int*b,int n){
for(res i=1;i<n;++i){
b[i-1]=mul(a[i],i);
}
}
inline void inter(int*a,int*b,int n){
b[0]=0;
for(res i=1;i<n;++i){
b[i]=mul(a[i-1],mpow(i,p-2));
}
}
void ln(int*a,int*b,int n){
direv(a,ls[2],n);
inv(a,ls[3],n);
mul(ls[2],ls[3],ls[4],n,n);
inter(ls[4],b,2*n);
for(res i=n;i<2*n;++i)b[i]=0;
}
void exp(int*a,int*b,int n){
b[0]=1;
for(res len=2;len<2*n;len*=2){
res limit=len*2;
ln(b,ls[5],len);
for(res i=0;i<len;++i){
ls[5][i]=add(p-ls[5][i],a[i]);
}
ls[5][0]=add(ls[5][0],1);
for(res i=0;i<len;++i)ls[6][i]=b[i];
mul(ls[5],ls[6],b,len,len);
for(res i=len;i<limit;++i)b[i]=0;
}
}
void sqrt(int*a,int*b,int n){
b[0]=1;
for(res len=2;len<2*n;len*=2){
res limit=len*2;
inv(b,ls[2],len);
for(res i=0;i<len;++i)ls[3][i]=a[i];
mul(ls[2],ls[3],ls[4],len,len);
for(res i=0;i<len;++i)b[i]=mul(add(b[i],ls[4][i]),inv2);
for(res i=len;i<limit;++i)b[i]=0;
}
}
inline int read(){
res ret=0;char c;
for(c=getchar();!isdigit(c);c=getchar());
for(;isdigit(c);ret=ret*10+c-'0',c=getchar());
return ret;
}
int main(){
init();
cin>>n;
for(res i=0;i<n;++i)a[i]=read();
sqrt(a,c,n);
for(res i=0;i<n;++i)printf("%d ",c[i]);puts("");
}
Part 6:多项式快速幂
直接换底公式即可
#include<bits/stdc++.h>
#define res register int
using namespace std;
const int N=3e6+10,p=998244353,G=3,Gi=332748118;
inline int add(res x,res y){return(x+y)%p;}
inline int mul(res x,res y){return 1ll*x*y%p;}
inline int mpow(res a,res n){
res ret=1;
while(n){
if(n&1)ret=mul(ret,a);
a=mul(a,a);
n/=2;
}
return ret;
}
int g[2][N];
int inv2;
void init(){
inv2=mpow(2,p-2);
for(int i=1;i<N;i*=2){
g[0][i]=mpow(G,(p-1)/i);
g[1][i]=mpow(Gi,(p-1)/i);
}
}
int n,m;
int ls[7][N],used;
//0,1 mul inv
//2,3,4 ln sqrt
//5,6 exp
int a[N],b[N],c[N];
int r[N];
void ntt(int*f,res limit,res op){
for(res i=0;i<limit;++i)if(i<r[i])swap(f[i],f[r[i]]);
for(res len=1;len<=limit;len*=2){
res wn=op==1?g[0][len]:g[1][len];
for(res j=0;j<limit;j+=len){
res w=1;
for(res k=j;k<j+len/2;++k,w=mul(w,wn)){
res x=f[k],y=mul(w,f[k+len/2]);
f[k]=add(x,y);
f[k+len/2]=add(x,p-y);
}
}
}
if(op==-1){
res inv=mpow(limit,p-2);
for(res i=0;i<limit;++i){
f[i]=mul(f[i],inv);
}
}
}
void mul(int*a,int*b,int*c,int n,int m){
int limit=1;
while(limit<n+m-1)limit*=2;
for(res i=0;i<limit;++i)ls[0][i]=ls[1][i]=0;
for(res i=1;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)*limit/2);
for(res i=0;i<limit;++i)ls[0][i]=a[i],ls[1][i]=b[i];
ntt(ls[0],limit,1);
ntt(ls[1],limit,1);
for(res i=0;i<limit;++i)c[i]=mul(ls[0][i],ls[1][i]);
ntt(c,limit,-1);
}
void inv(int*a,int*b,int n){
b[0]=mpow(a[0],p-2);
res limit;
for(res len=1;len<2*n;len*=2){
limit=len*2;
for(res i=1;i<limit;++i)r[i]=(r[i>>1]>>1)|((i&1)*limit/2);
for(res i=0;i<len;++i)ls[0][i]=a[i],ls[1][i]=b[i];
for(res i=len;i<limit;++i)ls[0][i]=ls[1][i]=0;
ntt(ls[0],limit,1),ntt(ls[1],limit,1);
for(res i=0;i<limit;++i){
b[i]=mul(add(2,p-(mul(ls[0][i],ls[1][i]))),ls[1][i]);
}
ntt(b,limit,-1);
for(res i=len;i<limit;++i)b[i]=0;
}
for(int i=n;i<limit;++i)b[i]=0;
}
inline void direv(int*a,int*b,int n){
for(res i=1;i<n;++i){
b[i-1]=mul(a[i],i);
}
b[n-1]=0;
}
inline void inter(int*a,int*b,int n){
b[0]=0;
for(res i=1;i<n;++i){
b[i]=mul(a[i-1],mpow(i,p-2));
}
}
void ln(int*a,int*b,int n){
direv(a,ls[2],n);
inv(a,ls[3],n);
mul(ls[2],ls[3],ls[4],n,n);
inter(ls[4],b,2*n);
for(res i=n;i<2*n;++i)b[i]=0;
for(res i=0;i<n;++i)ls[2][i]=ls[3][i]=ls[4][i]=0;
}
void exp(int*a,int*b,int n){
b[0]=1;
for(res len=1;len<2*n;len*=2){
res limit=len*2;
ln(b,ls[5],len);
for(res i=0;i<len;++i){
ls[5][i]=add(p-ls[5][i],a[i]);
}
ls[5][0]=add(ls[5][0],1);
for(res i=0;i<len;++i)ls[6][i]=b[i];
mul(ls[5],ls[6],b,len,len);
for(res i=len;i<limit;++i)b[i]=0;
}
}
void sqrt(int*a,int*b,int n){
b[0]=1;
for(res len=2;len<2*n;len*=2){
res limit=len*2;
inv(b,ls[2],len);
for(res i=0;i<len;++i)ls[3][i]=a[i];
mul(ls[2],ls[3],ls[4],len,len);
for(res i=0;i<len;++i)b[i]=mul(add(b[i],ls[4][i]),inv2);
for(res i=len;i<limit;++i)b[i]=0;
}
}
inline int read(){
res ret=0;char c;
for(c=getchar();!isdigit(c);c=getchar());
for(;isdigit(c);ret=add(mul(ret,10),c-'0'),c=getchar());
return ret;
}
int main(){
init();
n=read(),m=read();
for(res i=0;i<n;++i)a[i]=read();
ln(a,b,n);
for(int i=0;i<n;++i)b[i]=mul(b[i],m);
exp(b,c,n);
for(res i=0;i<n;++i)printf("%d ",c[i]);puts("");
}