前置芝士:
多项式的各种运算
这些运算都是在模意义下进行的运算,但多项式的取模运算与整数的取模运算有些不同。
多项式对 \(x^n\) 取模的意思是舍弃 \(x^n\) 以及更高次的部分。
多项式求逆
- 对于一个多项式 \(A(x)\) ,如果存在 \(B(x)\) 使得
\[A(x)B(x)\equiv 1\pmod {x^n} \]
- 那么称 \(B(x)\) 为 \(A(x)\) 在 \(mod\: x^n\) 意义下的逆元 \((inverse\:element)\),记作 \(A^{-1}(x)\)。
- 取模意义下,没有模数的逆元是没有意义的,因为不同的模数对应不一样的逆元。
推导
-
考虑用倍增法求解。
-
假如我们现在已经求出了 \(A(x)\) 在 \(mod\:x^{\frac{n}{2}}\) 意义下的逆元 \(B_0(x)\) ,即
\[A(x)B_0(x)\equiv 1\pmod {x^{\frac{n}{2}}} \] -
而
\[\because A(x)B(x)\equiv 1\pmod {x^{\frac{x}{2}}} \] -
两式相减并消去 \(A(x)\) 得
\[\therefore B(x)-B_0(x)\equiv 0\pmod {x^{\frac{n}{2}}} \] -
再同时平方
\[B^2(x)-2B(x)B_0(x)+B_0^2(x)\equiv 0\pmod {x^n}B^2(x)-2B(x)B_0(x)+B_0^2(x)\equiv 0\pmod {x^n} \] -
乘上 \(A(x)\),即可消去 \(B(x)\)
\[B(x)-2B_0(x)+A(x)B_0^2\equiv 0\pmod {x^n} \] -
所以得到递推式
\[B(x)=B_0(x)(2-A(x)B_0(x))\pmod {x^n} \] -
边界:当 \(n=1\) 时,\(B_0(x)\) 即为 \(A(x)\) 常数项的逆元。
-
然后就可以在 \(O(nlogn)\) 的时间复杂度内求逆啦
代码
递归版:
#include <iostream> //递归
#include <cstdio>
using namespace std;
const int maxn=4e5+10,mod=998244353,g=3,gn=332748118;
int p=1,bit,inver;
int f[maxn],h[maxn],c[maxn],rev[maxn];
inline int power(long long a,int x) {
long long ans=1;
while(x) {
if(x&1) ans=(ans*a)%mod;
a=(a*a)%mod;
x>>=1;
}
return ans;
}
inline void ntt(int *t,int len,int inv) {
for(int i=0;i<len;i++) {
if(i<rev[i]) swap(t[i],t[rev[i]]);
}
for(int mid=1;mid<len;mid<<=1) {
int unit=power(inv==1 ? g : gn,(mod-1)/(mid<<1));
int d=mid<<1;
for(int l=0;l<len;l+=d) {
int now=1;
for(int i=0;i<mid;i++) {
int x=t[l+i],y=(long long)now*t[l+mid+i]%mod;
t[l+i]=(x+y)%mod;
t[l+mid+i]=(x-y+mod)%mod;
now=(long long)now*unit%mod;
}
}
}
if(inv==-1) for(int i=0;i<len;i++) {
t[i]=(long long)t[i]*inver%mod;
}
}
inline void solve(int deg) {
if(deg==1) {
h[0]=power(f[0],mod-2);
return ;
}
solve((deg+1)>>1);
while(p<(deg<<1)) {p<<=1;bit++;}
inver=power(p,mod-2);
for(int i=1;i<p;i++) {
rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
}
for(int i=0;i<deg;i++) c[i]=f[i];
ntt(c,p,1),ntt(h,p,1);
for(int i=0;i<p;i++)
h[i]=h[i]*(2-(long long)c[i]*h[i]%mod+mod)%mod;
ntt(h,p,-1);
for(int i=deg;i<p;i++) h[i]=0; //必须要归零
}
int main() {
int n=read();
for(int i=0;i<n;i++) f[i]=read();
solve(n);
for(int i=0;i<n;i++) printf("%d ",h[i]);
putchar('\n');
return 0;
}
// by pycr
递推版:
#include <iostream> //递推
#include <cstdio>
using namespace std;
const int maxn=4e5+10,mod=998244353,g=3,gn=332748118;
int p=1,bit,inver;
int f[maxn],h[maxn],t[maxn],rev[maxn];
inline int power(long long a,int x) {
long long ans=1;
while(x) {
if(x&1) ans=(ans*a)%mod;
a=(a*a)%mod;
x>>=1;
}
return ans;
}
inline void ntt(int *t,int len,int inv) {
for(int i=0;i<len;i++) {
if(i<rev[i]) swap(t[i],t[rev[i]]);
}
for(int mid=1;mid<len;mid<<=1) {
int unit=power(inv==1 ? g : gn,(mod-1)/(mid<<1));
int d=mid<<1;
for(int l=0;l<len;l+=d) {
int now=1;
for(int i=0;i<mid;i++) {
int x=t[l+i],y=(long long)now*t[l+mid+i]%mod;
t[l+i]=(x+y)%mod;
t[l+mid+i]=(x-y+mod)%mod;
now=(long long)now*unit%mod;
}
}
}
if(inv==-1) for(int i=0;i<len;i++) {
t[i]=(long long)t[i]*inver%mod;
}
}
signed main() {
int n=read();
for(int i=0;i<n;i++) f[i]=read();
h[0]=power(f[0],mod-2);
for(int i=2;i<=(n<<1)-2;i<<=1) {
while(p<=2*i-3) {p<<=1,bit++;}
for(int j=1;j<p;j++) {
rev[j]=(rev[j>>1]>>1)|((j&1)<<(bit-1));
}
inver=power(p,mod-2);
for(int j=0;j<i;j++) t[j]=f[j];
ntt(t,p,1),ntt(h,p,1);
for(int j=0;j<p;j++) h[j]=h[j]*(2-(long long)t[j]*h[j]%mod+mod)%mod;
ntt(h,p,-1);
for(int j=i;j<p;j++) h[j]=0;
}
for(int i=0;i<n;i++) printf("%d ",h[i]);
putchar('\n');
return 0;
}
// by pycr
- 测出来都在 \(900ms\) 左右,相差 \(1ms\) ,
简直奇慢无比…… - \(Tips:\) 每一次递归(递推)结束后,都需要把 \(h\) 数组清零,不然会影响答案的正确性。
多项式对数函数
- 求 \(B(x)\equiv \ln\:A(x)\pmod {x^n}\)
推导
-
\(\ln\) 看着太碍眼了,有没有什么能够消除 \(\ln\) 的方法?
-
自然是有的,联系到我们之前学的微积分知识可以想到,用链规则对 \(\ln\:A(x)\) 求导可以得到 \(\frac{A'(x)}{A(x)}\) ,学过多项式的逆就很容易计算这个式子的答案了,最后对其积分就行。即:
\[\ln\:A(x)=\int \frac{A'(x)}{A(x)}dx \] -
\(Tips:\) 多项式常数项为 \(1\) 时才能取 \(\ln\) ,取后常数项为 \(0\)。
代码
#include <iostream>
#include <cstdio>
using namespace std;
const int maxn=4e5+10,mod=998244353,g=3,gn=(mod+1)/g;
int p=1,bit,inver;
int f[maxn],h[maxn],c[maxn],rev[maxn];
inline int power(long long a,int x) {
long long ans=1;
while(x) {
if(x&1) ans=(ans*a)%mod;
a=(a*a)%mod;
x>>=1;
}
return ans;
}
inline void ntt(int *t,int len,int inv) {
for(int i=0;i<len;i++) {
if(i<rev[i]) swap(t[i],t[rev[i]]);
}
for(int mid=1;mid<len;mid<<=1) {
int unit=power(inv==1 ? g : gn,(mod-1)/(mid<<1));
int d=mid<<1;
for(int l=0;l<len;l+=d) {
int now=1;
for(int i=0;i<mid;i++) {
int x=t[l+i],y=(long long)now*t[l+mid+i]%mod;
t[l+i]=(x+y)%mod;
t[l+mid+i]=(x-y+mod)%mod;
now=(long long)now*unit%mod;
}
}
}
if(inv==-1) for(int i=0;i<len;i++) {
t[i]=(long long)t[i]*inver%mod;
}
}
inline void getinv(int *f,int *h,int deg) {
if(deg==1) {
h[0]=power(f[0],mod-2);
return ;
}
getinv(f,h,(deg+1)>>1);
while(p<(deg<<1)) {p<<=1;bit++;}
inver=power(p,mod-2);
for(int i=1;i<p;i++)
rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
memcpy(c,f,deg*4);
ntt(h,p,1);ntt(c,p,1);
for(int i=0;i<p;i++)
h[i]=h[i]*(2-(long long)c[i]*h[i]%mod+mod)%mod;
ntt(h,p,-1);
for(int i=deg;i<p;i++) h[i]=0;
}
inline void derivative(int *t,int len) {
for(int i=1;i<len;i++) {
t[i-1]=(long long)i*t[i]%mod;
}
t[len-1]=0;
}
inline void integrate(int *t,int len) {
for(int i=len-1;i;i--) {
t[i]=(long long)t[i-1]*power(i,mod-2)%mod;
}
t[0]=0;
}
int main() {
int n=read();
for(int i=0;i<n;i++) f[i]=read();
getinv(f,h,n);
derivative(f,n);
ntt(f,p,1),ntt(h,p,1);
for(int i=0;i<p;i++) f[i]=(long long)f[i]*h[i]%mod;
ntt(f,p,-1);
integrate(f,n);
for(int i=0;i<n;i++) printf("%d ",f[i]);
putchar('\n');
return 0;
}
// by pycr
牛顿迭代
??怎么乱入啊?牛顿迭代也是多项式运算中比较重要的一部分。
-
多项式的牛顿迭代可不是用来在实数域和复数域上近似求解方程的。
-
其用来求解以下方程中的 \(B(x)\) :
\[G(B(x))\equiv 0\pmod {x^n} \] -
还是考虑倍增法:假设我们已经求出了 \(\frac{n}{2}\) 次多项式 \(B_0(x)\) 使得:
\[G(B_0(x))\equiv 0\pmod {x^{\frac{n}{2}}} \] -
结合之前泰勒展开的知识,将其在 \(B_0(x)\) 处泰勒展开:
\[\sum_{i=0}^{+\infty}\frac{G^{(i)}(B_0(x))}{i!}(B(x)-B_0(x))^i\equiv 0\pmod {x^n} \]因为 \(B(x)-B_0(x)\) 在 \(x^{\frac{n}{2}}\) 次项之下的系数都为 \(0\),所以其平方或者变成更高次幂之后在 \(mod\:x^n\) 意义下都为 \(0\),所以可以直接丢弃。
-
那么原式就变为
\[G(B(x))\equiv G(B_0(x))+G'(B_0(x))(B(x)-B_0(x))\pmod {x^n} \]因为 \(G(B(x))\equiv 0\pmod {x^n}\),得到
\[B(x)\equiv B_0(x)-\frac{G(B_0(x))}{G'(B_0(x))}\pmod {x^n} \] -
然后就可以愉快的递归(递推)啦。
考虑用牛顿迭代实现多项式求逆
-
其实很简单 -
设 \(G(B(x))=\frac{1}{B(x)}-A(x)\equiv 0\pmod {x^n}\)
-
则
\[\begin{aligned} B(x)&\equiv B_0(x)-\frac{G(B_0(x))}{G'(B_0(x))}&\pmod {x^n}\\ &\equiv B_0(x)-\frac{\frac{1}{B_0(x)}-A(x)}{-\frac{1}{B_0^2(x)}}&\pmod {x^n}\\ &\equiv 2\cdot B_0(x)-B_0^2A(x)&\pmod {x^n}\\ &\equiv B_0(x)(2-B_0(x)A(x))&\pmod {x^n} \end{aligned} \]
多项式指数函数
- 求 \(B(x)\equiv e^{A(x)}\pmod {x^n}\)
推导
-
这个需要用到牛顿迭代。
不然我之前讲迭代干嘛? -
考虑对两边同时取自然对数:
\[\ln B(x)\equiv A(x)\\ \] -
设函数 \(G(B(x))=\ln B(x)-A(x)\),套用牛顿迭代得:
\[\begin{aligned} B(x)&\equiv B_0(x)-\frac{G(B_0(x))}{G'(B_0(x))}&\pmod {x^n}\\ &\equiv B_0(x)-\frac{\ln B_0(x)-A(x)}{\frac{1}{B_0(x)}}&\pmod {x^n}\\ &\equiv B_0(x)(1-\ln B_0(x)-A(x))&\pmod {x^n} \end{aligned} \]结合之前的多项式对数函数即可。
-
\(Tips:\) 多项式常数项为 \(0\) 时才能取 \(\exp\) ,取后常数项为 \(1\) 。
代码
#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
const int maxn=4e5+10,mod=998244353,g=3,gn=(mod+1)/g;
int p=1,bit,inver;
int f[maxn],h[maxn],c[maxn],rev[maxn];
int h_ln[maxn],c_e[maxn],f_inv[maxn];
inline int power(long long a,int x) {
long long ans=1;
while(x) {
if(x&1) ans=(ans*a)%mod;
a=(a*a)%mod;
x>>=1;
}
return ans;
}
inline void ntt(int *t,int len,int inv) {
for(int i=0;i<len;i++) {
if(i<rev[i]) swap(t[i],t[rev[i]]);
}
for(int mid=1;mid<len;mid<<=1) {
int unit=power(inv==1 ? g : gn,(mod-1)/(mid<<1));
int d=mid<<1;
for(int l=0;l<len;l+=d) {
int now=1;
for(int i=0;i<mid;i++) {
int x=t[l+i],y=(long long)now*t[l+mid+i]%mod;
t[l+i]=(x+y)%mod;
t[l+mid+i]=(x-y+mod)%mod;
now=(long long)now*unit%mod;
}
}
}
if(inv==-1) for(int i=0;i<len;i++) {
t[i]=(long long)t[i]*inver%mod;
}
}
inline void getinv(int *f,int *h,int deg) {
if(deg==1) {
h[0]=power(f[0],mod-2);
return ;
}
getinv(f,h,(deg+1)>>1);
while(p<(deg<<1)) {p<<=1;bit++;}
inver=power(p,mod-2);
for(int i=1;i<p;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
memcpy(c,f,deg*4);
ntt(h,p,1);ntt(c,p,1);
for(int i=0;i<p;i++) h[i]=h[i]*(2-(long long)c[i]*h[i]%mod+mod)%mod;
ntt(h,p,-1);
for(int i=deg;i<p;i++) h[i]=0;
}
inline void derivative(int *src,int *t,int len) {
for(int i=1;i<len;i++) {
t[i-1]=(long long)i*src[i]%mod;
}
t[len-1]=0;
}
inline void integrate(int *t,int len) {
for(int i=len-1;i;i--) {
t[i]=(long long)t[i-1]*power(i,mod-2)%mod;
}
t[0]=0;
}
inline void getln(int *src,int *f,int len) {
p=1,bit=0;
memset(f_inv,0,sizeof(f_inv)); //必须要清零,因为ntt会算上deg之后的系数,就会超出p的范围
memset(c,0,sizeof(c));
getinv(src,f_inv,len);
derivative(src,f,len);
ntt(f,p,1),ntt(f_inv,p,1);
for(int i=0;i<p;i++) f[i]=(long long)f[i]*f_inv[i]%mod;
ntt(f,p,-1);
integrate(f,len);
}
inline void getexp(int *f,int *h,int deg) {
if(deg==1) {
h[0]=1;
return ;
}
getexp(f,h,(deg+1)>>1);
memset(h_ln,0,sizeof(h_ln)); //清零,避免爆ntt
getln(h,h_ln,deg);
memcpy(c_e,f,deg*4);
ntt(c_e,p,1),ntt(h,p,1);ntt(h_ln,p,1);
for(int i=0;i<p;i++)
h[i]=h[i]*(1ll-h_ln[i]+c_e[i]+mod)%mod;
ntt(h,p,-1);
for(int i=deg;i<p;i++) h[i]=0;
}
int main() {
int n=read();
for(int i=0;i<n;i++) f[i]=read();
getexp(f,h,n);
for(int i=0;i<n;i++) printf("%d ",h[i]);
putchar('\n');
return 0;
}
// by pycr
- \(Important:\) 为什么代码中会有三个
memset
呢?我在 \(FFT\&NTT\) 的总结中也有提及,因为如果在运算的时候后面的系数不为 \(0\) 的话,乘出来的实际结果可能就会大于所预估的长度 \(p\)。实际上后面有没有系数在模意义下是不会影响结果的,错误的真正原因是因为把 \(NTT\) 乘爆了。后面的系数不会影响结果的前提是 \(NTT\) 能够得到正确的多项式。简而言之:如果原本乘出来的结果的最高次项为 \(x^{n-1}\),那么就一定至少要有 \(n\) 个点,而后面的系数则有可能导致实际的多项式会有更高次项,超出我们预估的点数。
多项式开根
- 求 \(B^2(x)\equiv A(x)\pmod {x^n}\)
推导
-
仍然是牛顿迭代。
-
设 \(G(B(x))=B^2(x)-A(x)\),则
\[\begin{aligned} B(x)&\equiv B_0(x)-\frac{G(B_0(x))}{G'(B_0(x))}&\pmod {x^n}\\ &\equiv B_0(x)-\frac{B_0^2(x)-A(x)}{2B_0(x)}&\pmod {x^n}\\ &\equiv \frac{B_0^2(x)+A(x)}{2B_0(x)}&\pmod {x^n} \end{aligned} \]结合多项式求逆元得解。
代码
#include <iostream>
#include <cstdio>
#include <cstring>
namespace IO {
const int N=1<<20;
char buf[N],*l=buf,*r=buf;
inline char gc() {
if(l==r) r=(l=buf)+fread(buf,1,N,stdin);
return l==r ? EOF : *(l++);
}
inline int read() {
int x=0,s=1;
char ch=gc();
while(!isdigit(ch)) {if(ch=='-') s=-1;ch=gc();}
while(isdigit(ch)) {x=x*10+(ch^48);ch=gc();}
return x*s;
}
}
using namespace std;
using IO::read;
const int maxn=4e5+10,mod=998244353,g=3,gn=(mod+1)/g,inv_2=(mod+1)/2;
int p=1,bit;
int f[maxn],h[maxn],c[maxn],rev[maxn];
int h_inv[maxn],c_r[maxn];
inline int power(long long a,int x) {
long long ans=1;
while(x) {
if(x&1) ans=(ans*a)%mod;
a=(a*a)%mod;
x>>=1;
}
return ans;
}
inline void ntt(int *t,int inv) {
for(int i=0;i<p;i++) {
if(i<rev[i]) swap(t[i],t[rev[i]]);
}
for(int mid=1;mid<p;mid<<=1) {
int unit=power(inv==1 ? g : gn,(mod-1)/(mid<<1));
int d=mid<<1;
for(int l=0;l<p;l+=d) {
int now=1;
for(int i=0;i<mid;i++) {
int x=t[l+i],y=(long long)now*t[l+mid+i]%mod;
t[l+i]=(x+y)%mod;
t[l+mid+i]=(x-y+mod)%mod;
now=(long long)now*unit%mod;
}
}
}
if(inv==-1) {
int inver=power(p,mod-2);
for(int i=0;i<p;i++) {
t[i]=(long long)t[i]*inver%mod;
}
}
}
inline void getinv(int *f,int *h,int deg) {
if(deg==1) {
h[0]=power(f[0],mod-2);
return ;
}
getinv(f,h,(deg+1)>>1);
while(p<(deg<<1)) {p<<=1;bit++;}
for(int i=1;i<p;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
memcpy(c,f,deg*4);
ntt(h,1),ntt(c,1);
for(int i=0;i<p;i++) h[i]=h[i]*(2-(long long)c[i]*h[i]%mod+mod)%mod;
ntt(h,-1);
for(int i=deg;i<p;i++) h[i]=0;
}
inline void getroot(int *f,int *h,int deg) {
if(deg==1) {
h[0]=1;
return ;
}
getroot(f,h,(deg+1)>>1);
p=1,bit=0;
memset(h_inv,0,sizeof(h_inv)); //清零
memset(c,0,sizeof(c));
getinv(h,h_inv,deg);
memcpy(c_r,f,deg*4);
ntt(h,1),ntt(c_r,1),ntt(h_inv,1);
for(int i=0;i<p;i++) h[i]=((long long)h[i]*h[i]%mod+c_r[i])*inv_2%mod*h_inv[i]%mod;
ntt(h,-1);
for(int i=deg;i<p;i++) h[i]=0;
}
int main() {
//#ifndef ONLINE_JUDGE
#ifdef LOCAL
freopen("c.in","r",stdin);
//freopen("c.out","w",stdout);
#endif
//ios::sync_with_stdio(false);
//cin.tie(0);cout.tie(0);
int n=read();
for(int i=0;i<n;i++) f[i]=read();
getroot(f,h,n);
for(int i=0;i<n;i++) printf("%d ",h[i]);
putchar('\n');
return 0;
}
// by pycr
- \(Tips:\) 和之前一样,每次都需要清零。
——2021年2月8日