题目:https://www.luogu.org/problemnew/show/P5205
不会二次剩余。
牛顿迭代推开根式子:
( f^2(x)-g(x)=0 )
( f(x)=f_0(x)-frac{ f_0^2(x)-g(x) }{ ( f_0^2(x)-g(x) )' } = frac{ f_0^2(x)-g(x) }{ 2f_0(x) } )
实现的时候形如 ( f(x)=frac{ f_0(x)+frac{ g(x) }{ f_0(x) } }{2} )
用的 vector 。慢了很多。
#include<cstdio> #include<cstring> #include<algorithm> #include<vector> #define ll long long #define vi vector<int> #define pb push_back using namespace std; int rdn() { int ret=0;bool fx=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return fx?ret:-ret; } const int N=(1<<18)+5,mod=998244353; int upt(int x){while(x>=mod)x-=mod;while(x<0)x+=mod;return x;} int pw(int x,int k) {int ret=1;while(k){if(k&1)ret=(ll)ret*x%mod;x=(ll)x*x%mod;k>>=1;}return ret;} int n; vi f; namespace Pl{ int len,r[N]; vi ntt(vi a,bool fx) { for(int i=0;i<len;i++) if(i<r[i])swap(a[i],a[r[i]]); for(int R=2;R<=len;R<<=1) { int wn=pw(3,fx?(mod-1)-(mod-1)/R:(mod-1)/R); for(int i=0,m=R>>1;i<len;i+=R) for(int j=0,w=1;j<m;j++,w=(ll)w*wn%mod) { int x=a[i+j], y=(ll)w*a[i+m+j]%mod; a[i+j]=upt(x+y); a[i+m+j]=upt(x-y); } } if(!fx)return a; int inv=pw(len,mod-2); for(int i=0;i<len;i++)a[i]=(ll)a[i]*inv%mod; return a; } vi inv(vi f,int n) { int tp; for(tp=1;tp<n;tp<<=1); tp<<=1; vi a,b; a.resize(tp); b.resize(tp); b[0]=pw(f[0],mod-2); for(int t=2,yt=1;yt<n;yt=t,t=len) { len=t<<1; for(int i=0,j=len>>1;i<len;i++) r[i]=(r[i>>1]>>1)+((i&1)?j:0); for(int i=0;i<t;i++)a[i]=f[i]; a=ntt(a,0); b=ntt(b,0); for(int i=0;i<len;i++) b[i]=upt((ll)b[i]*(2-(ll)a[i]*b[i]%mod)%mod); b=ntt(b,1); for(int i=t;i<len;i++)a[i]=0; for(int i=t;i<len;i++)b[i]=0; } for(int i=n;i<len;i++)b[i]=0;// return b; } vi sqr(vi f,int n) { int iv2=pw(2,mod-2); int tp; for(tp=1;tp<n;tp<<=1); tp<<=1; vi a,b,c; a.resize(tp); b.resize(tp); b[0]=1; for(int t=2,yt=1;yt<n;yt=t,t=len) { for(int i=0;i<t;i++)a[i]=f[i]; c=inv(b,t); len=t<<1; c.resize(len);//resize for(int i=0,j=len>>1;i<len;i++) r[i]=(r[i>>1]>>1)+((i&1)?j:0); a=ntt(a,0); c=ntt(c,0); for(int i=0;i<len;i++) a[i]=(ll)a[i]*c[i]%mod; a=ntt(a,1); for(int i=0;i<t;i++)b[i]=(ll)(b[i]+a[i])*iv2%mod; for(int i=t;i<len;i++)a[i]=0; } return b; } } int main() { n=rdn()-1; int tp; for(tp=1;tp<=n;tp<<=1); tp<<=1; f.resize(tp); for(int i=0;i<=n;i++)f[i]=rdn(); f=Pl::sqr(f,n+1); for(int i=0;i<=n;i++)printf("%d ",f[i]);puts(""); return 0; }