Description
给定一个 (n-1) 次多项式 (A(x)),求一个在(mod x^n)意义下的多项式 (B(x)),使得 (B^2(x) equiv A(x)(mod x^n))
多项式的系数在 (mod 998244353) 的意义下进行运算。
(n leq 10^5,a_i in [0,998244352] cap mathbb{Z})
Solution
其实推导过程和多项式求逆类似
考虑倍增
假设我们已经求出了一个多项式 (G(x)) 使得 (G^2(x) equiv A(x) (mod x^{lceil frac{n}{2} ceil})) ,而 (B(x)) 本来就有 (B^2(x) equiv A(x) (mod x^{lceil frac{n}{2} ceil})) ,那么
平方差公式展开
在这里我们需要说一下究竟取哪个,又有什么区别的问题
假设题目要求的最终的答案为 (F(x))
因为是在模大质数意义下进行的运算,所以要么有 (B(x)equiv G(x)(mod x^{lceil frac{n}{2} ceil})) ,要么有 (B(x)+G(x)equiv 0(mod x^{lceil frac{n}{2} ceil})) ,至于为什么只需要关注一下 (0) 次项的系数就可以了
若我们在倍增的过程中全部选择 (B(x)equiv G(x)(mod x^{lceil frac{n}{2} ceil})) 或选择了偶数次 (B(x)+G(x)equiv 0(mod x^{lceil frac{n}{2} ceil})) ,那么最后得到的答案就是 (F(x)),反之若我们选择了奇数次 (B(x)+G(x)equiv 0(mod x^{lceil frac{n}{2} ceil})) ,那么最后得到的答案就是 (-F(x)) ,原因在下面的推导中不难看出。所以 (sqrt{A(x)}) 有两解,为 (pm F(x))
我们选择前者,即
移项后平方展开得到
即
移项得
然后除过去,得
多项式求逆+( ext{NTT})即可
#include<cstdio>
#include<iostream>
using namespace std;
const int N=1e5+10;
const int mod=998244353;
const int g=3;
const int invg=332748118;
int n,a[N<<2],b[N<<2],c[N<<2],d[N<<2],f[N<<2],h[N<<2],k;
inline void Add(int &x,int y){x+=y;x-=x>=mod? mod:0;}
inline int MOD(int x){x-=x>=mod? mod:0;return x;}
inline int fas(int x,int p){int res=1;while(p){if(p&1)res=1ll*res*x%mod;p>>=1;x=1ll*x*x%mod;}return res;}
inline void NTT(int *a,int f){
for(register int i=0,j=0;i<k;i++){
if(i>j)swap(a[i],a[j]);
for(register int l=k>>1;(j^=l)<l;l>>=1);}
for(register int i=1;i<k;i<<=1){
int w=fas(~f? g:invg,(mod-1)/(i<<1));
for(register int j=0;j<k;j+=(i<<1)){
int e=1;
for(register int p=0;p<i;p++,e=1ll*e*w%mod){
int x=a[j+p],y=1ll*a[j+p+i]*e%mod;
a[j+p]=MOD(x+y);a[j+p+i]=MOD(x-y+mod);
}
}
}
}
inline void PINV(int *a,int *b,int deg){
if(deg==1){b[0]=fas(a[0],mod-2);return;}
int M=(deg+1)>>1;PINV(a,b,M);
k=1;while(k<=deg+deg-2)k<<=1;int INV=fas(k,mod-2);
for(register int i=0;i<deg;i++)h[i]=a[i];
for(register int i=deg;i<k;i++)h[i]=0;
NTT(h,1);NTT(b,1);
for(register int i=0;i<k;i++)
b[i]=(2ll-1ll*h[i]*b[i]%mod+mod)*b[i]%mod;
NTT(b,-1);
for(register int i=0;i<deg;i++)b[i]=1ll*b[i]*INV%mod;
for(register int i=deg;i<k;i++)b[i]=0;
}
inline void Sqrt(int *a,int *b,int deg){
if(deg==1){b[0]=1;return;}
int M=(deg+1)>>1;Sqrt(a,b,M);
k=1;while(k<=deg+deg-2)k<<=1;int INV=fas(k,mod-2);
for(register int i=0;i<deg;i++)c[i]=b[i];
for(register int i=deg;i<k;i++)c[i]=0;
NTT(c,1);
for(register int i=0;i<k;i++)c[i]=1ll*c[i]*c[i]%mod;
NTT(c,-1);
for(register int i=0;i<deg;i++)c[i]=1ll*c[i]*INV%mod;
for(register int i=deg;i<k;i++)c[i]=0;
for(register int i=0;i<deg;i++)Add(c[i],a[i]);
for(register int i=0;i<deg;i++)d[i]=MOD(b[i]+b[i]);
for(register int i=deg;i<k;i++)d[i]=0;
for(register int i=0;i<k;i++)f[i]=0;
PINV(d,f,deg);
k=1;while(k<=deg+deg-2)k<<=1;
NTT(f,1);NTT(c,1);
for(register int i=0;i<k;i++)b[i]=1ll*f[i]*c[i]%mod;
NTT(b,-1);
for(register int i=0;i<deg;i++)b[i]=1ll*b[i]*INV%mod;
for(register int i=deg;i<k;i++)b[i]=0;
}
int main(){
scanf("%d",&n);n--;
for(register int i=0;i<=n;i++)scanf("%d",&a[i]);
Sqrt(a,b,n+1);
for(register int i=0;i<=n;i++)printf("%d%c",b[i],i==n? '
':' ');
return 0;
}