Description
在2016年,佳媛姐姐刚刚学习了第二类斯特林数,非常开心。
现在他想计算这样一个函数的值:
S(i, j)表示第二类斯特林数,递推公式为:
S(i, j) = j ∗ S(i − 1, j) + S(i − 1, j − 1), 1 <= j <= i − 1。
边界条件为:S(i, i) = 1(0 <= i), S(i, 0) = 0(1 <= i)
你能帮帮他吗?
Input
输入只有一个正整数
Output
输出f(n)。由于结果会很大,输出f(n)对998244353(7 × 17 × 223 + 1)取模的结果即可。1 ≤ n ≤ 100000
Sample Input
3
Sample Output
87
Sol
本题递推公式没用,因为我们需要在(nlogn)时间内求出这个结果。
首先我们根据第二类斯特林数的定义“把i个数字分到j的相同集合的方案数”,得:
(s(i,j)=frac{1}{j!}*sum_{k=0}^{j}(-1)^kC^k_j(j-k)^i)
意义就是我们假设至少k个集合为空,然后用组合数算出选择空集的方案数,之后剩下(j-k)个集合,i个数字,那么方案数就是((j-k)^i)。但是这样不保证剩下的严格非空,所以要容斥一波。然后因为上述是有序的,而第二类斯特林数是无序的,所以要除以阶乘。
之后我们把组合数拆开,得到:
(s(i,j)=sum_{k=0}^{j}frac{(-1)^k}{k!}*frac{(j-k)^i}{(j-k)!})
带入所求的式子中,得到:
(f(n)=sum_{i=0}^{n}sum_{j=0}^{i}sum_{k=0}^{j}frac{(-1)^k}{k!}*frac{(j-k)^i}{(j-k)!})
因为(j>i)的时候右面不会产生贡献,所以我们可以把j的范围写到n:
(f(n)=sum_{i=0}^{n}sum_{j=0}^{n}j!2^jsum_{k=0}^{j}frac{(-1)^k}{k!}*frac{(j-k)^i}{(j-k)!})
尽管是卷积,但是还是得(O(n^2logn)),但是i只用到了一处,所以我们更换循环顺序,得:
(f(i)=sum_{j=0}^{n}j!2^jsum_{k=0}^{j}frac{(-1)^k}{k!}*frac{sum_{i=0}^{n}(j-k)^i}{(j-k)!})
那个带(sum)的是个等比数列,可以(O(1))计算,然后就是(NTT)啦。
时间复杂度(O(nlogn))。
Code
#include <cstdio>
int i,j,k,I[100005],IF[100005],a[262145],b[262145],F[262145],B[100005],P=998244353,w,wn,t,n,len,ans;
int ksm(int a,int b){int res=1;for(;b;b>>=1,a=1ll*a*a%P) if(b&1) res=1ll*res*a%P;return res;}
void ntt(int *a,int n,int op)
{
for(i=k=0;i<n;i++){if(i>k) a[i]^=a[k]^=a[i]^=a[k];for(j=(n>>1);(k^=j)<j;j>>=1);}
for(k=2,wn=ksm(3,op==1?(P-1)/k:P-1-(P-1)/k);k<=n;k<<=1,wn=ksm(3,op==1?(P-1)/k:P-1-(P-1)/k))
for(i=0,w=1;i<n;i+=k,w=1) for(j=0;j<(k>>1);j++,w=1ll*w*wn%P)
t=1ll*a[i+j+(k>>1)]*w%P,a[i+j+(k>>1)]=(a[i+j]-t+P)%P,a[i+j]=(a[i+j]+t)%P;
if(op==-1) for(t=ksm(n,P-2),i=0;i<n;i++) a[i]=1ll*a[i]*t%P;
}
int main()
{
scanf("%d",&n);for(len=1;len<=n*2;len<<=1);
F[0]=B[0]=I[1]=F[1]=1,B[1]=2;for(int i=2;i<=n;i++) F[i]=1ll*F[i-1]*i%P,B[i]=1ll*B[i-1]*2%P,I[i]=1ll*I[P%i]*(P-P/i)%P;
IF[n]=ksm(F[n],P-2);for(int i=n-1;~i;i--) IF[i]=1ll*IF[i+1]*(i+1)%P;
a[0]=b[0]=1;for(int i=1;i<=n;i++) b[i]=(((i&1)?-1:1)*IF[i]+P)%P,a[i]=(i==1)?n+1:1ll*IF[i]*(ksm(i,n+1)-1)%P*I[i-1]%P;
ntt(a,len,1);ntt(b,len,1);
for(int i=0;i<len;i++) a[i]=1ll*a[i]*b[i]%P;ntt(a,len,-1);
for(int i=0;i<=n;i++) ans=(ans+1ll*B[i]*F[i]%P*a[i]%P)%P;
printf("%d
",ans);
}