概述
多项式求逆元是一个非常重要的知识点,许多多项式操作都需要用到该算法,包括多项式取模,除法,开跟,求ln,求exp,快速幂。用快速傅里叶变换和倍增法可以在$O(n log n)$的时间复杂度下求出一个$n$次多项式的逆元。
前置技能
快速数论变换(NTT),求一个数$x$在模$p$意义下的乘法逆元。
多项式的逆元
给定一个多项式$A(x)$,其次数为$deg_A$,若存在一个多项式$B(x)$,使其满足$deg_B≤deg_A$,且$A(x) imes B(x) equiv 1 (mod x^n)$,则$B(x)$即为$A(x)$在模$x^n$意义下的的乘法逆元。
求多项式的逆元
我们不妨假设,$n=2^k,k∈N$。
若$n=1$,则$A(x) imes B(x) equiv a_0 imes b_0 equiv 1 (mod x^1)$。其中$a_0$,$b_0$表示多项式$A$和多项式$B$的常数项。
若需要求出$b_0$,直接用费马小定理求出$a_0$的乘法逆元即可。
当$n>1$时:
我们假设在模$x^{frac{n}{2}}$的意义下$A(x)$的逆元$B'(x)$我们已经求得。
依据定义,则有
$A(x)B'(x)equiv 1 (mod x^{frac{n}{2}})$ $(1)$
对$(1)$式进行移项得
$A(x)B'(x)-1equiv 0 (mod x^{frac{n}{2}})$ $(2)$
然后对$(2)$式等号两边平方,得
$A^2(x)B'^2(x)-2A(x)B'(x)+1equiv 0(mod x^{n})$ $(3)$
将常数项移动到等式右侧,得
$A^2(x)B'^2(x)-2A(x)B'(x)equiv -1(mod x^{n})$ $(4)$
将等式两边去相反数,得
$2A(x)B'(x)-A^2(x)B'^2(x)equiv 1(mod x^{n})$ $(5)$
下面考虑回我们需要求的多项式$B(x)$,依据定义,其满足
$A(x)B(x)equiv 1(mod x^{n}) $(6)$
将$(5)-(6)$并移项,得
$A(x)B(x)equiv 2A(x)B'(x)-A^2(x)B'^2(x)(mod x^{n})$ $(7)$
等式两边约去$A(x)$,得
$B(x)equiv 2B'(x)-A(x)B'^2(x)(mod x^{n})$ $(8)$
显然,我们可以用上述式子求出$B(x)$。
这一步的计算我们可以使用$NTT$,时间复杂度为$O(n log n)$。
我们可以通过递归的方法,求解出$B(x)$。
时间复杂度$T(n)=T(dfrac{n}{2})+O(n log n)=O(n log n)$。
洛谷上有一道题目就叫做多项式求逆元(点这里),可以先做下那一题。
模板如下:
1 #include<bits/stdc++.h> 2 #define M (1<<19) 3 #define L long long 4 #define MOD 998244353 5 #define G 3 6 using namespace std; 7 8 L pow_mod(L x,L k){ 9 L ans=1; 10 while(k){ 11 if(k&1) ans=ans*x%MOD; 12 x=x*x%MOD; k>>=1; 13 } 14 return ans; 15 } 16 17 void change(L a[],int n){ 18 for(int i=0,j=0;i<n-1;i++){ 19 if(i<j) swap(a[i],a[j]); 20 int k=n>>1; 21 while(j>=k) j-=k,k>>=1; 22 j+=k; 23 } 24 } 25 void NTT(L a[],int n,int on){ 26 change(a,n); 27 for(int h=2;h<=n;h<<=1){ 28 L wn=pow_mod(G,(MOD-1)/h); 29 for(int j=0;j<n;j+=h){ 30 L w=1; 31 for(int k=j;k<j+(h>>1);k++){ 32 L u=a[k],t=w*a[k+(h>>1)]%MOD; 33 a[k]=(u+t)%MOD; 34 a[k+(h>>1)]=(u-t+MOD)%MOD; 35 w=w*wn%MOD; 36 } 37 } 38 } 39 if(on==-1){ 40 L inv=pow_mod(n,MOD-2); 41 for(int i=0;i<n;i++) a[i]=a[i]*inv%MOD; 42 reverse(a+1,a+n); 43 } 44 } 45 46 void getinv(L a[],L b[],int n){ 47 if(n==1){b[0]=pow_mod(a[0],MOD-2); return;} 48 static L c[M],d[M]; 49 memset(c,0,n<<4); memset(d,0,n<<4); 50 getinv(a,c,n>>1); 51 for(int i=0;i<n;i++) d[i]=a[i]; 52 NTT(d,n<<1,1); NTT(c,n<<1,1); 53 for(int i=0;i<(n<<1);i++) b[i]=(2*c[i]-d[i]*c[i]%MOD*c[i]%MOD+MOD)%MOD; 54 NTT(b,n<<1,-1); 55 for(int i=0;i<n;i++) b[n+i]=0; 56 } 57 L a[M]={0},b[M]={0}; 58 int main(){ 59 int n,N; scanf("%d",&n); 60 for(int i=0;i<=n;i++) scanf("%lld",a+i); 61 for(N=1;N<=n;N<<=1); 62 getinv(a,b,N); 63 for(int i=0;i<=n;i++) printf("%lld ",b[i]); 64 }