思路
调了半天，最后发现是 ln 里忘了清空（复用的静态）数组……
核心就是下面这个式子（前提是 $A(0)=1$，这样 $\ln A(x)$ 才有定义）：
\[
A^k(x) \equiv e^{k \ln A(x)} \pmod{x^n}
\]
代码
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
// NTT-friendly prime modulus: 998244353 = 119 * 2^23 + 1.
const int MOD = 998244353;
// Primitive root modulo MOD and its inverse (3 * 332748118 == MOD + 1).
const int G = 3;
const int invG = 332748118;
// Capacity of every coefficient / scratch buffer; transform lengths must not exceed it.
const int MAXN = 300000;
// Bit-reversal permutation shared by cal_rev() and NTT().
int rev[MAXN];
// inv_val[i] holds i^{-1} mod MOD once init_inv() has filled it.
int inv_val[MAXN];
// Modular exponentiation by square-and-multiply: returns a^b mod MOD.
// Expects b >= 0; returns 1 when b == 0 (including 0^0).
int pow(int a,int b){
    long long result=1;
    long long base=a;
    for(;b;b>>=1){
        // Fold in the current power of a when this exponent bit is set.
        if(b&1)
            result=result*base%MOD;
        base=base*base%MOD;
    }
    return (int)result;
}
// Fill rev[0..n) with the bit-reversal permutation of lim-bit indices (n == 1 << lim):
// rev[i] is i with its lim low bits reversed. NTT() uses it to reorder its input.
// rev[0] is written explicitly, so the result no longer depends on what the buffer
// held before, and n == 1 (lim == 0) no longer evaluates a shift by a negative count.
void cal_rev(int *rev,int n,int lim){
    if(n<=0)
        return;
    rev[0]=0;
    // Reversing i = reversing i>>1, shifting it down one place, and moving i's lowest
    // bit to the top position (bit lim-1). Only runs for n >= 2, where lim >= 1.
    for(int i=1;i<n;++i)
        rev[i]=(rev[i>>1]>>1)|((i&1)<<(lim-1));
}
// In-place iterative number-theoretic transform of length n (a power of two) modulo MOD.
//   opt != 0 : forward transform using the primitive root G.
//   opt == 0 : inverse transform using invG, then every entry is scaled by n^{-1}.
// The global rev[] must already hold the bit-reversal permutation for this length
// (see cal_rev). lim is accepted for symmetry with the callers but is not used here.
void NTT(int *a,int opt,int n,int lim){
    // Bit-reversal reordering so the butterflies below can work in place.
    for(int idx=0;idx<n;++idx){
        const int partner=rev[idx];
        if(idx<partner)
            swap(a[idx],a[partner]);
    }
    const int root=opt?G:invG;
    // half is the butterfly distance; each pass merges blocks of size 2*half.
    for(int half=1;(half<<1)<=n;half<<=1){
        const int block=half<<1;
        // Principal block-th root of unity (its inverse for the inverse transform).
        const int step=pow(root,(MOD-1)/block);
        for(int base=0;base<n;base+=block){
            int w=1;
            for(int off=0;off<half;++off){
                int &lo=a[base+off];
                int &hi=a[base+off+half];
                const int t=(int)(1LL*hi*w%MOD);
                hi=(lo-t+MOD)%MOD;
                lo=(lo+t)%MOD;
                w=(int)(1LL*w*step%MOD);
            }
        }
    }
    if(opt==0){
        const int scale=pow(n,MOD-2);
        for(int idx=0;idx<n;++idx)
            a[idx]=(int)(1LL*a[idx]*scale%MOD);
    }
}
void mul(int *a,int *b,int &at,int bt){
static int tmp1[MAXN];
int num=(at+bt),n=1,lim=0;
while(n<=(num+2))
n<<=1,lim++;
for(int i=0;i<n;++i)
tmp1[i]=b[i];
cal_rev(rev,n,lim);
NTT(a,1,n,lim);
NTT(tmp1,1,n,lim);
for(int i=0;i<n;++i)
a[i]=(1LL*a[i]*tmp1[i])%MOD;
NTT(a,0,n,lim);
at=num;
}
// Polynomial inverse by Newton iteration: fills b[0..dep) so that a*b == 1 (mod x^dep).
// The base case sets b[0] = a[0]^(MOD-2), so a[0] must be nonzero modulo MOD.
// midlen/midlim are the current transform length (a power of two) and its log2. They are
// shared by reference across the recursion and only grow; ln() in this file starts them
// at 1 and 0.
// Side effects: overwrites the global rev[] and uses the function-static tmp[] as scratch.
// On return b[dep..midlen) is zero.
// NOTE(review): b is transformed over [0, midlen) without being cleared first, so entries
// past the current approximation must already be zero on entry; callers rely on this.
void inv(int *a,int *b,int dep,int &midlen,int &midlim){
if(dep==1){
b[0]=pow(a[0],MOD-2);
return;
}
// First get the inverse to half precision: ceil(dep/2) coefficients.
inv(a,b,(dep+1)>>1,midlen,midlim);
static int tmp[MAXN];
// Grow the transform so midlen >= 2*dep; the cyclic product b*b*a then cannot wrap around.
while((dep<<1)>midlen)
midlen<<=1,midlim++;
// tmp <- a mod x^dep, zero-padded to the transform length.
for(int i=0;i<dep;++i)
tmp[i]=a[i];
for(int i=dep;i<midlen;++i)
tmp[i]=0;
cal_rev(rev,midlen,midlim);
NTT(tmp,1,midlen,midlim);
NTT(b,1,midlen,midlim);
// Newton step evaluated pointwise: b <- b * (2 - a*b).
for(int i=0;i<midlen;++i)
b[i]=1LL*b[i]*(2-1LL*tmp[i]*b[i]%MOD+MOD)%MOD;
NTT(b,0,midlen,midlim);
// Truncate to dep coefficients so higher terms do not leak into later steps.
for(int i=dep;i<midlen;++i)
b[i]=0;
}
// Differentiate in place: a(x) of degree at becomes a'(x) of degree at-1.
// The former top slot a[at] is zeroed and at is decremented.
void qd(int *a,int &at){
    // Ascending order matters: a[deg] is read before it is overwritten.
    for(int deg=1;deg<=at;++deg)
        a[deg-1]=(int)(1LL*deg*a[deg]%MOD);
    a[at--]=0;
}
// Integrate in place: a(x) of degree at becomes its antiderivative with zero constant
// term, and at is incremented. Needs inv_val[1..at+1] to be filled (see init_inv).
void jf(int *a,int &at){
    ++at;
    // Descending order matters: a[i-1] is read before it is overwritten.
    int i=at;
    while(i>0){
        a[i]=(int)(1LL*inv_val[i]*a[i-1]%MOD);
        --i;
    }
    a[0]=0;
}
// Polynomial logarithm: b <- ln(a) mod x^(at+1), computed as the integral of a'/a.
// at is the degree of a; it is modified internally and restored before returning.
// The constant term of the result is always 0 (set by jf), which equals ln(a) only when
// a[0] == 1.
// Side effects: uses the function-static tmp[] (cleared before returning) and, through
// inv()/mul(), the global rev[]. Writes b[0..bt] and zeroes b above degree bt up to the
// product degree.
void ln(int *a,int *b,int &at){
static int tmp[MAXN];
// midlen/midlim seed inv()'s transform size; tmpt tracks tmp's degree; bt saves at.
int midlen=1,midlim=0,tmpt=at,bt=at;
// tmp <- copy of a, so it can be differentiated without modifying the input.
for(int i=0;i<=at;++i)
tmp[i]=a[i];
// b <- a^{-1} mod x^(at+1)
inv(a,b,at+1,midlen,midlim);
// tmp <- a' (degree at-1)
qd(tmp,tmpt);
// b <- a^{-1} * a'; at becomes the product degree at + (at-1)
mul(b,tmp,at,tmpt);
// b <- integral of b over degrees 0..bt (b[0] = 0)
jf(b,tmpt);
// Reset the static scratch: mul() reads its second operand over the whole transform
// length, so leftover values would corrupt the next call.
for(int i=0;i<=bt;i++)
tmp[i]=0;
// Above degree bt, b still holds un-integrated product terms; drop them.
for(int i=bt+1;i<=at;++i)
tmp[i]=0,b[i]=0;
at=bt;
}
// Polynomial exponential by Newton iteration: on return b[0..dep) holds exp(a) mod x^dep,
// and b is zero from index dep through the product degree written by this call.
// Each level doubles the precision: b <- b * (1 + a - ln(b)).
// a[0..dep) is read. NOTE(review): a[0] is expected to be 0, since b[0] is fixed to
// 1 = exp(0). b is expected to be zero beyond the coefficients this function writes
// (true for the global buffer used by main); confirm for other callers.
// Side effects: uses the function-static tmp1[] and, via ln()/mul(), the global rev[].
void exp(int *a,int *b,int dep){
    if(dep==1){
        b[0]=1;
        return;
    }
    // b <- exp(a) mod x^ceil(dep/2)
    exp(a,b,(dep+1)>>1);
    static int tmp1[MAXN];
    for(int i=0;i<dep;++i)
        tmp1[i]=0;
    // tmp1 <- ln(b) mod x^(dep+1); ln() restores dep before returning.
    ln(b,tmp1,dep);
    // tmp1 <- 1 + a - ln(b)  (mod x^dep)
    for(int i=0;i<dep;++i)
        tmp1[i]=(a[i]-tmp1[i]+MOD)%MOD;
    tmp1[0]=(tmp1[0]+1)%MOD;
    // ln() also produced a degree-dep term. Drop it so only tmp1[0..dep) enters the
    // product below, because mul() reads its operand over the whole transform length.
    tmp1[dep]=0;
    // b <- b * tmp1, treating both as degree dep-1; prod_deg receives 2*dep-2.
    int prod_deg=dep-1;
    mul(b,tmp1,prod_deg,dep-1);
    // Truncate to dep coefficients. The bound is inclusive so the top product
    // coefficient b[prod_deg] is cleared as well.
    for(int i=dep;i<=prod_deg;++i)
        b[i]=0;
}
// Tabulate modular inverses: inv_val[i] = i^{-1} mod MOD for 1 <= i <= n, using
// MOD = q*i + r  =>  i^{-1} = -q * r^{-1}  (mod MOD).
// inv_val[0] and inv_val[1] are always written, even when n < 2.
void init_inv(int n){
    inv_val[0]=0;
    inv_val[1]=1;
    int i=2;
    while(i<=n){
        const int quot=MOD/i;
        const int rem=MOD%i;
        inv_val[i]=(int)((long long)(MOD-quot)*inv_val[rem]%MOD);
        ++i;
    }
}
// Scale the n+1 coefficients a[0..n] by k, reducing modulo MOD.
void mul(int *a,int n,int k){
    for(int i=n;i>=0;--i)
        a[i]=(int)(a[i]*(long long)k%MOD);
}
void pow(int *a,int *b,int n,int k){
static int tmp[MAXN];
int t=n;
ln(a,tmp,t);
mul(tmp,n,k);
exp(tmp,b,n);
}
int a[MAXN],b[MAXN],n,k;
// Entry point. Input: n, then the exponent k in decimal (it may be far too long for any
// integer type), then the n coefficients a[0..n-1] of A(x).
// Output: the n coefficients of A(x)^k mod x^n, each followed by a space.
int main(){
    if(scanf("%d",&n)!=1||n<=0)
        return 0;
    // Largest power-of-two transform length that fits in the MAXN-sized buffers.
    int cap=1;
    while((cap<<1)<=MAXN)
        cap<<=1;
    // The inversion inside ln() runs at a length >= 2*(n+1); reject inputs that would
    // overflow the fixed-size buffers instead of corrupting memory.
    if(n>cap/2-1){
        fprintf(stderr,"n is too large for MAXN\n");
        return 1;
    }
    init_inv(n+10);
    // Read k digit by digit, reducing modulo MOD (pow() only uses k modulo MOD).
    // c is an int so that EOF is distinguishable from every character; a char here
    // would make the skip loop spin forever on truncated input.
    int c=getchar();
    while(c!=EOF&&(c<'0'||c>'9'))
        c=getchar();
    while(c>='0'&&c<='9'){
        k=(int)((10LL*k+(c-'0'))%MOD);
        c=getchar();
    }
    for(int i=0;i<n;++i)
        if(scanf("%d",&a[i])!=1)
            break;
    pow(a,b,n,k);
    for(int i=0;i<n;++i)
        printf("%d ",b[i]);
    return 0;
}