zoukankan      html  css  js  c++  java
  • 总结:多项式的运算

    多项式的小总结

    前置芝士:

    多项式的各种运算

    这些运算都是在模意义下进行的运算,但多项式的取模运算与整数的取模运算有些不同。

    多项式对 \(x^n\) 取模的意思是舍弃 \(x^n\) 以及更高次的部分。

    多项式求逆

    • 对于一个多项式 \(A(x)\) ,如果存在 \(B(x)\) 使得

    \[A(x)B(x)\equiv 1\pmod {x^n} \]

    • 那么称 \(B(x)\)\(A(x)\)\(mod\: x^n\) 意义下的逆元 \((inverse\:element)\),记作 \(A^{-1}(x)\)
    • 取模意义下,没有模数的逆元是没有意义的,因为不同的模数对应不一样的逆元。

    推导

    • 考虑用倍增法求解。

    • 假如我们现在已经求出了 \(A(x)\)\(mod\:x^{\frac{n}{2}}\) 意义下的逆元 \(B_0(x)\) ,即

      \[A(x)B_0(x)\equiv 1\pmod {x^{\frac{n}{2}}} \]

    • \[\because A(x)B(x)\equiv 1\pmod {x^{\frac{x}{2}}} \]

    • 两式相减并消去 \(A(x)\)

      \[\therefore B(x)-B_0(x)\equiv 0\pmod {x^{\frac{n}{2}}} \]

    • 再同时平方

      \[B^2(x)-2B(x)B_0(x)+B_0^2(x)\equiv 0\pmod {x^n}B^2(x)-2B(x)B_0(x)+B_0^2(x)\equiv 0\pmod {x^n} \]

    • 乘上 \(A(x)\),即可消去 \(B(x)\)

      \[B(x)-2B_0(x)+A(x)B_0^2\equiv 0\pmod {x^n} \]

    • 所以得到递推式

      \[B(x)=B_0(x)(2-A(x)B_0(x))\pmod {x^n} \]

    • 边界:当 \(n=1\) 时,\(B_0(x)\) 即为 \(A(x)\) 常数项的逆元。

    • 然后就可以在 \(O(nlogn)\) 的时间复杂度内求逆啦

    代码

    递归版:

    #include <iostream>			//递归
    #include <cstdio>
    
    using namespace std;
    
    const int maxn=4e5+10,mod=998244353,g=3,gn=332748118;
    
    int p=1,bit,inver;
    int f[maxn],h[maxn],c[maxn],rev[maxn];
    
    inline int power(long long a,int x) {
    	long long ans=1;
    	while(x) {
    		if(x&1) ans=(ans*a)%mod;
    		a=(a*a)%mod;
    		x>>=1;
    	}
    	return ans;
    }
    
    inline void ntt(int *t,int len,int inv) {
    	for(int i=0;i<len;i++) {
    		if(i<rev[i]) swap(t[i],t[rev[i]]);
    	}
    	for(int mid=1;mid<len;mid<<=1) {
    		int unit=power(inv==1 ? g : gn,(mod-1)/(mid<<1));
    		int d=mid<<1;
    		for(int l=0;l<len;l+=d) {
    			int now=1;
    			for(int i=0;i<mid;i++) {
    				int x=t[l+i],y=(long long)now*t[l+mid+i]%mod;
    				t[l+i]=(x+y)%mod;
    				t[l+mid+i]=(x-y+mod)%mod;
    				now=(long long)now*unit%mod;
    			}
    		}
    	}
    	if(inv==-1) for(int i=0;i<len;i++) {
    		t[i]=(long long)t[i]*inver%mod;
    	}
    }
    
    inline void solve(int deg) {
    	if(deg==1) {
    		h[0]=power(f[0],mod-2);
    		return ;
    	}
    	solve((deg+1)>>1);
    	while(p<(deg<<1)) {p<<=1;bit++;}
    	inver=power(p,mod-2);
    	for(int i=1;i<p;i++) {
    		rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
    	}
    	for(int i=0;i<deg;i++) c[i]=f[i];
    	ntt(c,p,1),ntt(h,p,1);
    	for(int i=0;i<p;i++) 
    		h[i]=h[i]*(2-(long long)c[i]*h[i]%mod+mod)%mod;
    	ntt(h,p,-1);
    	for(int i=deg;i<p;i++) h[i]=0;			//必须要归零
    }
    
    int main() {
    	int n=read();
    	for(int i=0;i<n;i++) f[i]=read();
    	solve(n);
    	for(int i=0;i<n;i++) printf("%d ",h[i]);
    	putchar('\n');
    	return 0;
    }
    
    // by pycr
    

    递推版:

    #include <iostream>			//递推
    #include <cstdio>
    
    using namespace std;
    
    const int maxn=4e5+10,mod=998244353,g=3,gn=332748118;
    
    int p=1,bit,inver;
    int f[maxn],h[maxn],t[maxn],rev[maxn];
    
    inline int power(long long a,int x) {
    	long long ans=1;
    	while(x) {
    		if(x&1) ans=(ans*a)%mod;
    		a=(a*a)%mod;
    		x>>=1;
    	}
    	return ans;
    }
    
    inline void ntt(int *t,int len,int inv) {
    	for(int i=0;i<len;i++) {
    		if(i<rev[i]) swap(t[i],t[rev[i]]);
    	}
    	for(int mid=1;mid<len;mid<<=1) {
    		int unit=power(inv==1 ? g : gn,(mod-1)/(mid<<1));
    		int d=mid<<1;
    		for(int l=0;l<len;l+=d) {
    			int now=1;
    			for(int i=0;i<mid;i++) {
    				int x=t[l+i],y=(long long)now*t[l+mid+i]%mod;
    				t[l+i]=(x+y)%mod;
    				t[l+mid+i]=(x-y+mod)%mod;
    				now=(long long)now*unit%mod;
    			}
    		}
    	}
    	if(inv==-1) for(int i=0;i<len;i++) {
    		t[i]=(long long)t[i]*inver%mod;
    	}
    }
    
    signed main() {
    	int n=read();
    	for(int i=0;i<n;i++) f[i]=read();
    	h[0]=power(f[0],mod-2);
    	for(int i=2;i<=(n<<1)-2;i<<=1) {
    		while(p<=2*i-3) {p<<=1,bit++;}
    		for(int j=1;j<p;j++) {
    			rev[j]=(rev[j>>1]>>1)|((j&1)<<(bit-1));
    		}
    		inver=power(p,mod-2);
    		for(int j=0;j<i;j++) t[j]=f[j];
    		ntt(t,p,1),ntt(h,p,1);
    		for(int j=0;j<p;j++) h[j]=h[j]*(2-(long long)t[j]*h[j]%mod+mod)%mod;
    		ntt(h,p,-1);
    		for(int j=i;j<p;j++) h[j]=0;
    	}
    	for(int i=0;i<n;i++) printf("%d ",h[i]);
    	putchar('\n');
    	return 0;
    }
    
    // by pycr
    
    • 测出来都在 \(900ms\) 左右,相差 \(1ms\)简直奇慢无比……
    • \(Tips:\) 每一次递归(递推)结束后,都需要把 \(h\) 数组清零,不然会影响答案的正确性。

    多项式对数函数

    • \(B(x)\equiv \ln\:A(x)\pmod {x^n}\)

    推导

    • \(\ln\) 看着太碍眼了,有没有什么能够消除 \(\ln\) 的方法?

    • 自然是有的,联系到我们之前学的微积分知识可以想到,用链规则对 \(\ln\:A(x)\) 求导可以得到 \(\frac{A'(x)}{A(x)}\) ,学过多项式的逆就很容易计算这个式子的答案了,最后对其积分就行。即:

      \[\ln\:A(x)=\int \frac{A'(x)}{A(x)}dx \]

    • \(Tips:\) 多项式常数项为 \(1\) 时才能取 \(\ln\) ,取后常数项为 \(0\)

    代码

    #include <iostream>
    #include <cstdio>
    
    using namespace std;
    
    const int maxn=4e5+10,mod=998244353,g=3,gn=(mod+1)/g;
    
    int p=1,bit,inver;
    int f[maxn],h[maxn],c[maxn],rev[maxn];
    
    inline int power(long long a,int x) {
    	long long ans=1;
    	while(x) {
    		if(x&1) ans=(ans*a)%mod;
    		a=(a*a)%mod;
    		x>>=1;
    	}
    	return ans;
    }
    
    inline void ntt(int *t,int len,int inv) {
    	for(int i=0;i<len;i++) {
    		if(i<rev[i]) swap(t[i],t[rev[i]]);
    	}
    	for(int mid=1;mid<len;mid<<=1) {
    		int unit=power(inv==1 ? g : gn,(mod-1)/(mid<<1));
    		int d=mid<<1;
    		for(int l=0;l<len;l+=d) {
    			int now=1;
    			for(int i=0;i<mid;i++) {
    				int x=t[l+i],y=(long long)now*t[l+mid+i]%mod;
    				t[l+i]=(x+y)%mod;
    				t[l+mid+i]=(x-y+mod)%mod;
    				now=(long long)now*unit%mod;
    			}
    		}
    	}
    	if(inv==-1) for(int i=0;i<len;i++) {
    		t[i]=(long long)t[i]*inver%mod;
    	}
    }
    
    inline void getinv(int *f,int *h,int deg) {
    	if(deg==1) {
    		h[0]=power(f[0],mod-2);
    		return ;
    	}
    	getinv(f,h,(deg+1)>>1);
    	while(p<(deg<<1)) {p<<=1;bit++;}
    	inver=power(p,mod-2);
    	for(int i=1;i<p;i++) 
    		rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
    	memcpy(c,f,deg*4);
    	ntt(h,p,1);ntt(c,p,1);
    	for(int i=0;i<p;i++) 
    		h[i]=h[i]*(2-(long long)c[i]*h[i]%mod+mod)%mod;
    	ntt(h,p,-1);
    	for(int i=deg;i<p;i++) h[i]=0;
    }
    
    inline void derivative(int *t,int len) {
    	for(int i=1;i<len;i++) {
    		t[i-1]=(long long)i*t[i]%mod;
    	}
    	t[len-1]=0;
    }
    
    inline void integrate(int *t,int len) {
    	for(int i=len-1;i;i--) {
    		t[i]=(long long)t[i-1]*power(i,mod-2)%mod;
    	}
    	t[0]=0;
    }
    
    int main() {
    	int n=read();
    	for(int i=0;i<n;i++) f[i]=read();
    	getinv(f,h,n);
    	derivative(f,n);
    	ntt(f,p,1),ntt(h,p,1);
    	for(int i=0;i<p;i++) f[i]=(long long)f[i]*h[i]%mod;
    	ntt(f,p,-1);
    	integrate(f,n);
    	for(int i=0;i<n;i++) printf("%d ",f[i]);
    	putchar('\n');
    	return 0;
    }
    
    // by pycr
    

    牛顿迭代

    ??怎么乱入啊? 牛顿迭代也是多项式运算中比较重要的一部分。

    • 多项式的牛顿迭代可不是用来在实数域和复数域上近似求解方程的。

    • 其用来求解以下方程中的 \(B(x)\)

      \[G(B(x))\equiv 0\pmod {x^n} \]

    • 还是考虑倍增法:假设我们已经求出了 \(\frac{n}{2}\) 次多项式 \(B_0(x)\) 使得:

      \[G(B_0(x))\equiv 0\pmod {x^{\frac{n}{2}}} \]

    • 结合之前泰勒展开的知识,将其在 \(B_0(x)\) 处泰勒展开:

      \[\sum_{i=0}^{+\infty}\frac{G^{(i)}(B_0(x))}{i!}(B(x)-B_0(x))^i\equiv 0\pmod {x^n} \]

      因为 \(B(x)-B_0(x)\)\(x^{\frac{n}{2}}\) 次项之下的系数都为 \(0\),所以其平方或者变成更高次幂之后在 \(mod\:x^n\) 意义下都为 \(0\),所以可以直接丢弃​​。

    • 那么原式就变为

      \[G(B(x))\equiv G(B_0(x))+G'(B_0(x))(B(x)-B_0(x))\pmod {x^n} \]

      因为 \(G(B(x))\equiv 0\pmod {x^n}\),得到

      \[B(x)\equiv B_0(x)-\frac{G(B_0(x))}{G'(B_0(x))}\pmod {x^n} \]

    • 然后就可以愉快的递归(递推)啦。

    考虑用牛顿迭代实现多项式求逆

    • 其实很简单

    • \(G(B(x))=\frac{1}{B(x)}-A(x)\equiv 0\pmod {x^n}\)

    • \[\begin{aligned} B(x)&\equiv B_0(x)-\frac{G(B_0(x))}{G'(B_0(x))}&\pmod {x^n}\\ &\equiv B_0(x)-\frac{\frac{1}{B_0(x)}-A(x)}{-\frac{1}{B_0^2(x)}}&\pmod {x^n}\\ &\equiv 2\cdot B_0(x)-B_0^2A(x)&\pmod {x^n}\\ &\equiv B_0(x)(2-B_0(x)A(x))&\pmod {x^n} \end{aligned} \]

    多项式指数函数

    • \(B(x)\equiv e^{A(x)}\pmod {x^n}\)

    推导

    • 这个需要用到牛顿迭代。不然我之前讲迭代干嘛?

    • 考虑对两边同时取自然对数:

      \[\ln B(x)\equiv A(x)\\ \]

    • 设函数 \(G(B(x))=\ln B(x)-A(x)\),套用牛顿迭代得:

      \[\begin{aligned} B(x)&\equiv B_0(x)-\frac{G(B_0(x))}{G'(B_0(x))}&\pmod {x^n}\\ &\equiv B_0(x)-\frac{\ln B_0(x)-A(x)}{\frac{1}{B_0(x)}}&\pmod {x^n}\\ &\equiv B_0(x)(1-\ln B_0(x)-A(x))&\pmod {x^n} \end{aligned} \]

      结合之前的多项式对数函数即可。

    • \(Tips:\) 多项式常数项为 \(0\) 时才能取 \(\exp\) ,取后常数项为 \(1\)

    代码

    #include <iostream>
    #include <cstdio>
    #include <cstring>
    
    using namespace std;
    
    const int maxn=4e5+10,mod=998244353,g=3,gn=(mod+1)/g;
    
    int p=1,bit,inver;
    int f[maxn],h[maxn],c[maxn],rev[maxn];
    int h_ln[maxn],c_e[maxn],f_inv[maxn];
    
    inline int power(long long a,int x) {
    	long long ans=1;
    	while(x) {
    		if(x&1) ans=(ans*a)%mod;
    		a=(a*a)%mod;
    		x>>=1;
    	}
    	return ans;
    }
    
    inline void ntt(int *t,int len,int inv) {
    	for(int i=0;i<len;i++) {
    		if(i<rev[i]) swap(t[i],t[rev[i]]);
    	}
    	for(int mid=1;mid<len;mid<<=1) {
    		int unit=power(inv==1 ? g : gn,(mod-1)/(mid<<1));
    		int d=mid<<1;
    		for(int l=0;l<len;l+=d) {
    			int now=1;
    			for(int i=0;i<mid;i++) {
    				int x=t[l+i],y=(long long)now*t[l+mid+i]%mod;
    				t[l+i]=(x+y)%mod;
    				t[l+mid+i]=(x-y+mod)%mod;
    				now=(long long)now*unit%mod;
    			}
    		}
    	}
    	if(inv==-1) for(int i=0;i<len;i++) {
    		t[i]=(long long)t[i]*inver%mod;
    	}
    }
    
    inline void getinv(int *f,int *h,int deg) {
    	if(deg==1) {
    		h[0]=power(f[0],mod-2);
    		return ;
    	}
    	getinv(f,h,(deg+1)>>1);
    	while(p<(deg<<1)) {p<<=1;bit++;}
    	inver=power(p,mod-2);
    	for(int i=1;i<p;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
    	memcpy(c,f,deg*4);
    	ntt(h,p,1);ntt(c,p,1);
    	for(int i=0;i<p;i++) h[i]=h[i]*(2-(long long)c[i]*h[i]%mod+mod)%mod;
    	ntt(h,p,-1);
    	for(int i=deg;i<p;i++) h[i]=0;
    }
    
    inline void derivative(int *src,int *t,int len) {
    	for(int i=1;i<len;i++) {
    		t[i-1]=(long long)i*src[i]%mod;
    	}
    	t[len-1]=0;
    }
    
    inline void integrate(int *t,int len) {
    	for(int i=len-1;i;i--) {
    		t[i]=(long long)t[i-1]*power(i,mod-2)%mod;
    	}
    	t[0]=0;
    }
    
    inline void getln(int *src,int *f,int len) {
    	p=1,bit=0;
    	memset(f_inv,0,sizeof(f_inv));			//必须要清零,因为ntt会算上deg之后的系数,就会超出p的范围
    	memset(c,0,sizeof(c));
    	getinv(src,f_inv,len);
    	derivative(src,f,len);
    	ntt(f,p,1),ntt(f_inv,p,1);
    	for(int i=0;i<p;i++) f[i]=(long long)f[i]*f_inv[i]%mod;
    	ntt(f,p,-1);
    	integrate(f,len);
    }
    
    inline void getexp(int *f,int *h,int deg) {
    	if(deg==1) {
    		h[0]=1;
    		return ;
    	}
    	getexp(f,h,(deg+1)>>1);
    	memset(h_ln,0,sizeof(h_ln));			//清零,避免爆ntt
    	getln(h,h_ln,deg);
    	memcpy(c_e,f,deg*4);
    	ntt(c_e,p,1),ntt(h,p,1);ntt(h_ln,p,1);
    	for(int i=0;i<p;i++) 
    		h[i]=h[i]*(1ll-h_ln[i]+c_e[i]+mod)%mod;
    	ntt(h,p,-1);
    	for(int i=deg;i<p;i++) h[i]=0;
    }
    
    int main() {
    	int n=read();
    	for(int i=0;i<n;i++) f[i]=read();
    	getexp(f,h,n);
    	for(int i=0;i<n;i++) printf("%d ",h[i]);
    	putchar('\n');
    	return 0;
    }
    
    // by pycr
    
    • \(Important:\) 为什么代码中会有三个 memset 呢?我在 \(FFT\&NTT\) 的总结中也有提及,因为如果在运算的时候后面的系数不为 \(0\) 的话,乘出来的实际结果可能就会大于所预估的长度 \(p\)。实际上后面有没有系数在模意义下是不会影响结果的,错误的真正原因是因为把 \(NTT\) 乘爆了。后面的系数不会影响结果的前提是 \(NTT\) 能够得到正确的多项式。简而言之:如果原本乘出来的结果的最高次项为 \(x^{n-1}\),那么就一定至少要有 \(n\) 个点,而后面的系数则有可能导致实际的多项式会有更高次项,超出我们预估的点数。

    多项式开根

    • \(B^2(x)\equiv A(x)\pmod {x^n}\)

    推导

    • 仍然是牛顿迭代。

    • \(G(B(x))=B^2(x)-A(x)\),则

      \[\begin{aligned} B(x)&\equiv B_0(x)-\frac{G(B_0(x))}{G'(B_0(x))}&\pmod {x^n}\\ &\equiv B_0(x)-\frac{B_0^2(x)-A(x)}{2B_0(x)}&\pmod {x^n}\\ &\equiv \frac{B_0^2(x)+A(x)}{2B_0(x)}&\pmod {x^n} \end{aligned} \]

      结合多项式求逆元得解。

    代码

    #include <iostream>
    #include <cstdio>
    #include <cstring>
    
    namespace IO {
    	const int N=1<<20;
    	char buf[N],*l=buf,*r=buf;
    	inline char gc() {
    		if(l==r) r=(l=buf)+fread(buf,1,N,stdin);
    		return l==r ? EOF : *(l++);
    	}
    	inline int read() {
    		int x=0,s=1;
    		char ch=gc();
    		while(!isdigit(ch)) {if(ch=='-') s=-1;ch=gc();}
    		while(isdigit(ch)) {x=x*10+(ch^48);ch=gc();}
    		return x*s;
    	}
    }
    
    using namespace std;
    using IO::read;
    
    const int maxn=4e5+10,mod=998244353,g=3,gn=(mod+1)/g,inv_2=(mod+1)/2;
    
    int p=1,bit;
    int f[maxn],h[maxn],c[maxn],rev[maxn];
    int h_inv[maxn],c_r[maxn];
    
    inline int power(long long a,int x) {
    	long long ans=1;
    	while(x) {
    		if(x&1) ans=(ans*a)%mod;
    		a=(a*a)%mod;
    		x>>=1;
    	}
    	return ans;
    }
    
    inline void ntt(int *t,int inv) {
    	for(int i=0;i<p;i++) {
    		if(i<rev[i]) swap(t[i],t[rev[i]]);
    	}
    	for(int mid=1;mid<p;mid<<=1) {
    		int unit=power(inv==1 ? g : gn,(mod-1)/(mid<<1));
    		int d=mid<<1;
    		for(int l=0;l<p;l+=d) {
    			int now=1;
    			for(int i=0;i<mid;i++) {
    				int x=t[l+i],y=(long long)now*t[l+mid+i]%mod;
    				t[l+i]=(x+y)%mod;
    				t[l+mid+i]=(x-y+mod)%mod;
    				now=(long long)now*unit%mod;
    			}
    		}
    	}
    	if(inv==-1) {
    		int inver=power(p,mod-2);
    		for(int i=0;i<p;i++) {
    			t[i]=(long long)t[i]*inver%mod;
    		}
    	}
    }
    
    inline void getinv(int *f,int *h,int deg) {
    	if(deg==1) {
    		h[0]=power(f[0],mod-2);
    		return ;
    	}
    	getinv(f,h,(deg+1)>>1);
    	while(p<(deg<<1)) {p<<=1;bit++;}
    	for(int i=1;i<p;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
    	memcpy(c,f,deg*4);
    	ntt(h,1),ntt(c,1);
    	for(int i=0;i<p;i++) h[i]=h[i]*(2-(long long)c[i]*h[i]%mod+mod)%mod;
    	ntt(h,-1);
    	for(int i=deg;i<p;i++) h[i]=0;
    }
    
    inline void getroot(int *f,int *h,int deg) {
    	if(deg==1) {
    		h[0]=1;
    		return ;
    	}
    	getroot(f,h,(deg+1)>>1);
    	p=1,bit=0;
    	memset(h_inv,0,sizeof(h_inv));			//清零
    	memset(c,0,sizeof(c));
    	getinv(h,h_inv,deg);
    	memcpy(c_r,f,deg*4);
    	ntt(h,1),ntt(c_r,1),ntt(h_inv,1);
    	for(int i=0;i<p;i++) h[i]=((long long)h[i]*h[i]%mod+c_r[i])*inv_2%mod*h_inv[i]%mod;
    	ntt(h,-1);
    	for(int i=deg;i<p;i++) h[i]=0;
    }
    
    int main() {
    //#ifndef ONLINE_JUDGE
    #ifdef LOCAL
    	freopen("c.in","r",stdin);
    	//freopen("c.out","w",stdout);
    #endif
    	//ios::sync_with_stdio(false);
    	//cin.tie(0);cout.tie(0);
    	int n=read();
    	for(int i=0;i<n;i++) f[i]=read();
    	getroot(f,h,n);
    	for(int i=0;i<n;i++) printf("%d ",h[i]);
    	putchar('\n');
    	return 0;
    }
    
    // by pycr
    
    • \(Tips:\) 和之前一样,每次都需要清零。

    ——2021年2月8日

    靡不有初,鲜克有终
  • 相关阅读:
    python学习笔记七--数据操作符
    ggplot2入门与进阶(下)
    ggplot2入门与进阶(上)
    ggplot2绘制Excel所有图
    机器学习中的数学-强大的矩阵奇异值分解(SVD)及其应用
    奇异值分解(SVD)原理详解及推导
    玩深度学习选哪块英伟达 GPU?有性价比排名还不够!
    深度学习主机攒机小记
    日志分析方法概述 & Web日志挖掘分析的方法
    python中matplotlib的颜色及线条控制
  • 原文地址:https://www.cnblogs.com/pycr/p/14391938.html
Copyright © 2011-2022 走看看