题目链接
(Solution)
先化简式子:
[f(n)=sum_{i=0}^nsum_{j=0}^iegin{Bmatrix} i \ j end {Bmatrix}*2^j*j!
]
[f(n)=sum_{j=0}^n2^j*j!sum_{i=0}^negin{Bmatrix} i \ j end {Bmatrix}
]
根据第二类斯特林数的公式:
[f(n)=sum_{j=0}^n2^j*j!sum_{i=0}^nsum_{k=0}^jfrac{(-1)^k}{k!}*frac{(j-k)^i}{(j-k)!}
]
[f(n)=sum_{j=0}^n2^j*j!sum_{k=0}^jfrac{(-1)^k}{k!}*frac{sum_{i=0}^n(j-k)^i}{(j-k)!}
]
我们令(F(i)=frac{(-1)^k}{i!},G(i)=frac{sum_{j=0}^ni^j}{i!})
根据等比数列求和公式可得:
(sum_{j=0}^ni^j=frac{i^{n+1}-1}{i-1})
所以(G(i)=frac{i^{n+1}-1}{(i-1)*i!})
(G(0)=1,G(1)=n+1)
那么
[f(n)=sum_{j=0}^n2^j*j!sum_{k=0}^jF(k)*G(j-k)
]
因为当(j>k)时(G(j-k)=0)
所以原式等价于:
[f(n)=sum_{j=0}^n2^j*j!sum_{k=0}^nF(k)*G(j-k)
]
这个东西直接(NTT)搞一搞就好了
(Code)
#include<bits/stdc++.h>
#define int long long
#define rg register
#define file(x) freopen(x".in","r",stdin);freopen(x".out","w",stdout);
using namespace std;
const int mod=998244353;
const int N=500010;
int read(){
int x=0,f=1;
char c=getchar();
while(c<'0'||c>'9') f=(c=='-')?-1:1,c=getchar();
while(c>='0'&&c<='9') x=x*10+c-48,c=getchar();
return f*x;
}
int f[N],g[N],r[N],limit=1,w[N],inv[N],a[N],b[N],jc[N];
int ksm(int a,int b){
int ans=1;
while(b){
if(b&1) ans=ans*a%mod;
a=a*a%mod,b>>=1;
}
return ans;
}
void ntt(int *a,int opt){
for(int i=0;i<=limit;i++)
if(i<r[i])
swap(a[i],a[r[i]]);
for(int i=1;i<limit;i<<=1){
int w=ksm(3,(mod-1)/(i*2));
if(opt==-1) w=ksm(w,mod-2);
for(int j=0;j<limit;j+=i<<1){
int l=1;
for(int k=j;k<j+i;k++){
int p=l*a[k+i]%mod;
a[k+i]=(a[k]-p+mod)%mod;
a[k]=(a[k]+p)%mod;
l=l*w%mod;
}
}
}
}
main(){
int n=read(),l=0,ans=0;
while(limit<=(n<<1))
limit<<=1,l++;
jc[0]=1;
for(int i=1;i<=limit;i++)
jc[i]=jc[i-1]*i%mod;
inv[limit]=ksm(jc[limit],mod-2);
for(int i=limit-1;i>=0;i--)
inv[i]=inv[i+1]*(i+1)%mod;
g[0]=1,g[1]=n+1,f[1]=mod-1,f[0]=1;
for(int i=2;i<=n;i++)
g[i]=(ksm(i,n+1)-1)%mod*inv[i]%mod*ksm(i-1,mod-2)%mod,f[i]=i&1?mod-inv[i]:inv[i];
for(int i=0;i<limit;i++)
r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
ntt(f,1),ntt(g,1);
for(int i=0;i<limit;i++)
f[i]=f[i]*g[i]%mod;
ntt(f,-1);
int inv=ksm(limit,mod-2);
for(int i=0;i<=n;i++)
ans=(ans+ksm(2,i)*jc[i]%mod*f[i]%mod*inv%mod)%mod;
printf("%lld",ans);
}