多项式开方学习笔记
前言:
今天学习了多项式开方,和多项式求逆挺像的,总结一下。
问题:
给定一个多项式(A(x)),求出多项式(B(x)),使(A(x) equiv B(x)^2 pmod{x^n})。
解析:
考虑递推求解,假设我们已经求出(B'(x)),使
[A(x) equiv B'(x)^2 pmod{x^{lceil frac{n}{2}
ceil}}
]
又:
[A(x) equiv B(x)^2 pmod{x^n}
]
所以:
[B(x)^2-B'(x)^2 equiv 0 pmod{x^{lceil frac{n}{2}
ceil}}
]
用平方差公式,有:
[(B(x)+B'(x))(B(x)-B'(x)) equiv 0 pmod{x^{lceil frac{n}{2}
ceil}}
]
取$$B(x)-B'(x) equiv 0 pmod{x^{lceil frac{n}{2} ceil}}$$
将式子两边平方,有:
[B(x)^2-2B(x)B'(x)+B'(x)^2 equiv 0 pmod{x^n}
]
那么我们就得到了递推式:
[B(x) equiv frac{A(x)+B'(x)^2}{2B'(x)} pmod{x^n}
]
多项式求逆即可。
最后当(n=1)时,求(B'(x))的常数项用二次剩余即可,这个可以看我的博客。
时间复杂度:
[T(n)=T(n/2)+O(nlogn),T(n)=O(nlogn)
]
代码实现
这是洛谷模板的代码。
#include<bits/stdc++.h>
#define N 300005
using namespace std;
inline int In(){
char c=getchar(); int x=0,ft=1;
for(;c<'0'||c>'9';c=getchar()) if(c=='-') ft=-1;
for(;c>='0'&&c<='9';c=getchar()) x=x*10+c-'0';
return x*ft;
}
const int P=998244353,g=3,inv_2=499122177;
int n,L,C,r[N],a[N],b[N],c[N],d[N],e[N];
inline int power(int x,int k){
if(!x) return 0;
int s=1,t=x;
for(;k;k>>=1,t=1ll*t*t%P) if(k&1) s=1ll*s*t%P;
return s;
}
inline void NTT_prepare(int x){
L=1; C=0; while(L<=x) L<<=1,++C;
for(int i=1;i<L;++i) r[i]=(r[i>>1]>>1)|((i&1)<<(C-1));
}
inline void NTT(int* A,int op){
for(int i=0;i<L;++i) if(i<r[i]) swap(A[i],A[r[i]]);
for(int i=1;i<L;i<<=1){
int Wn=power(g,(P-1)/(i<<1));
if(op==-1) Wn=power(Wn,P-2);
for(int j=0;j<L;j+=(i<<1)){
int w=1;
for(int k=0;k<i;++k,w=1ll*w*Wn%P){
int p=A[j+k],q=1ll*w*A[i+j+k]%P;
A[j+k]=(p+q)%P; A[i+j+k]=((p-q)%P+P)%P;
}
}
}
if(op==-1){
int inv_L=power(L,P-2);
for(int i=0;i<L;++i)
A[i]=1ll*inv_L*A[i]%P;
}
}
void Sol_inv(int k,int* A,int* B,int* C){
if(k==1){ B[0]=power(A[0],P-2); return; }
Sol_inv((k+1)/2,A,B,C); NTT_prepare(k<<1);
for(int i=0;i<k;++i) C[i]=A[i]; for(int i=k;i<L;++i) C[i]=0;
NTT(B,1); NTT(C,1);
for(int i=0;i<L;++i) B[i]=1ll*(2-1ll*B[i]*C[i]%P+P)%P*B[i]%P;
NTT(B,-1); for(int i=k;i<L;++i) B[i]=0;
}
void Sol_sqrt(int k,int* A,int* B,int* C,int* D){
if(k==1){ B[0]=1; return; }
Sol_sqrt((k+1)/2,A,B,C,D); NTT_prepare(k<<1);
for(int i=0;i<k;++i) C[i]=A[i]; for(int i=k;i<L;++i) C[i]=0;
for(int i=0;i<L;++i) D[i]=0;
Sol_inv(k,B,D,e);
NTT(B,1); NTT(C,1); NTT(D,1);
for(int i=0;i<L;++i) B[i]=1ll*(1ll*C[i]*D[i]%P+B[i])%P*inv_2%P;
NTT(B,-1); for(int i=k;i<L;++i) B[i]=0;
}
int main(){
n=In(); for(int i=0;i<n;++i) a[i]=In(); Sol_sqrt(n,a,b,c,d);
for(int i=0;i<n;++i) printf("%d%c",b[i],(i==n-1)?'
':' ');
return 0;
}