多校训练 8，有官方题解。
主要是之前没写过 NTT，对原根还不太理解。
先贴一份代码当模板吧。
#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
typedef long long ll;
using namespace std;

// NTT-friendly prime: mo - 1 = 119 * 2^23, and g = 3 is a primitive root of mo.
const int mo = 998244353;
const int g = 3;
// Largest m the factorial / power tables are built for.
// NOTE(review): assumes the problem guarantees m <= MAXM — confirm against the statement.
const int MAXM = 100000;
// Transform buffer size; n (smallest power of two > 2*m+1) is at most 2^18 for MAXM.
const int MAXN = (MAXM + 10) * 4;

int e[MAXM + 10];                                // input values, 1-indexed
int a[MAXN], b[MAXN], r[MAXN];                   // convolution operands, bit-reversal table
int jc[MAXM + 10], ni[MAXM + 10], d[MAXM + 10];  // i!, (i!)^-1, 2^i  (all mod mo)
int w[30][2], rg[30];                            // per-level roots of unity, (2^l)^-1
int ans[MAXM + 10];
int cas, m, n;                                   // n: current transform length

// Descending-order comparator for sort().
bool cmp(int a, int b)
{
    return a > b;
}

// Returns x^y mod mo by binary exponentiation (y >= 0).
ll quick(ll x, int y)
{
    ll s = 1;
    x %= mo;  // keep x*x below 2^63 even if a caller passes x >= mo
    while (y)
    {
        if (y & 1) s = s * x % mo;
        x = x * x % mo;
        y >>= 1;
    }
    return s;
}

// In-place iterative NTT of a[0..n-1], where n is the global transform length
// (a power of two) and r[] already holds the matching bit-reversal permutation.
// f == 1: forward transform (roots w[l][1]).
// f == 0: inverse transform (roots w[l][0]) WITHOUT the 1/n scaling; the caller
//         multiplies the results by rg[log2(n)] afterwards.
void ntt(int *a, int f)
{
    for (int i = 0; i < n; i++)
        if (i < r[i]) swap(a[i], a[r[i]]);
    int level = 0;
    for (int half = 1; half < n; half <<= 1)
    {
        // Primitive (2*half)-th root of unity for this butterfly level (or its inverse).
        int step = w[++level][f];
        for (int j = 0; j < n; j += half << 1)
        {
            // Running twiddle factor; named `cur` so it does not shadow the global table w.
            int cur = 1;
            for (int k = 0; k < half; k++)
            {
                int u = a[j + k];
                int v = 1ll * cur * a[j + k + half] % mo;
                a[j + k] = (u + v) % mo;
                a[j + k + half] = (u - v + mo) % mo;
                cur = 1ll * cur * step % mo;
            }
        }
    }
}

int main()
{
#ifdef LOCAL
    // Local debugging only. Unconditionally reopening stdin from "1.in" breaks
    // judge submissions: the file does not exist there, stdin becomes unusable,
    // and nothing is ever read or printed.
    freopen("1.in", "r", stdin);
#endif

    // For every level l with 2^l | (mo - 1):
    //   w[l][1] = g^((mo-1)/2^l), a primitive 2^l-th root of unity,
    //   w[l][0] = its inverse,
    //   rg[l]   = (2^l)^-1, the inverse-transform scaling factor for n = 2^l.
    int ng = quick(g, mo - 2);
    for (int l = 1; l < 30 && (mo - 1) % (1 << l) == 0; l++)
    {
        int pw = (mo - 1) >> l;
        w[l][1] = quick(g, pw);
        w[l][0] = quick(ng, pw);
        rg[l] = quick(1 << l, mo - 2);
    }

    // Factorials and powers of two, then inverse factorials in O(MAXM):
    // one exponentiation for MAXM!^-1, then (i-1)!^-1 = i!^-1 * i.
    jc[0] = d[0] = 1;
    for (int i = 1; i <= MAXM; i++)
    {
        jc[i] = 1ll * jc[i - 1] * i % mo;
        d[i] = 2 * d[i - 1] % mo;
    }
    ni[MAXM] = quick(jc[MAXM], mo - 2);
    for (int i = MAXM; i > 0; i--)
        ni[i - 1] = 1ll * ni[i] * i % mo;

    if (scanf("%d", &cas) != 1) return 0;
    while (cas--)
    {
        scanf("%d", &m);
        for (int i = 1; i <= m; i++)
            scanf("%d", &e[i]);
        sort(e + 1, e + 1 + m, cmp);  // descending order

        // Transform length: smallest power of two > 2m+1 (the product has degree 2m).
        int l = 0;
        for (n = 1; n <= 2 * m + 1; n <<= 1) l++;

        // Only a[0..n-1] / b[0..n-1] take part in the transform, so clearing that
        // prefix is enough (no need to wipe the whole buffers on every test case).
        memset(a, 0, n * sizeof(int));
        memset(b, 0, n * sizeof(int));

        // a[i] = (i-1)! * 2^(m-i) * e[i],  b[m-j] = 1/j!.
        // Then for t = 0..m-1, coefficient m+1+t of a*b, multiplied by 1/t!, equals
        //   sum_{i=t+1..m} C(i-1, t) * 2^(m-i) * e[i].
        for (int i = 1; i <= m; i++)
            a[i] = 1ll * jc[i - 1] * d[m - i] % mo * e[i] % mo;
        for (int i = 0; i <= m; i++)
            b[m - i] = ni[i];

        /* O(m^2) brute-force reference for the same coefficients:
        for (int i = m + 1; i < 2 * m + 1; i++)
        {
            int s = 0;
            for (int j = 0; j <= i; j++) s = (s + 1ll * a[j] * b[i - j] % mo) % mo;
            cout << 1ll * s * ni[i - m - 1] % mo << " ";
        }
        cout << endl;
        break;
        */

        for (int i = 0; i < n; i++)
            r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
        ntt(a, 1);
        ntt(b, 1);
        for (int i = 0; i < n; i++)
            a[i] = 1ll * a[i] * b[i] % mo;
        ntt(a, 0);

        // rg[l] applies the 1/n scaling that ntt(.., 0) omits.
        ans[0] = 0;
        for (int i = m + 1; i < 2 * m + 1; i++)
            ans[i - m] = 1ll * a[i] * ni[i - m - 1] % mo * rg[l] % mo;
        // Output prefix sums of the per-k values.
        for (int i = 1; i <= m; i++)
            ans[i] = (ans[i - 1] + ans[i]) % mo;
        for (int i = 1; i <= m; i++)
            printf("%d ", ans[i]);
        putchar('\n');
    }
    return 0;
}