$f(n)=sumlimits_{i=0}^{n} sumlimits_{j=0}^{i} S(i,j) imes 2^j imes j!$
其中$S(i,j)$为第二类斯特林数,公式为$S(i,j)=frac{1}{j!} sumlimits_{k=0}^{j} (-1)^k C(j,k) (j-k)^i$
求$f(n)$,$n<=100000$,答案对$998244353(=2^{23} imes 7 imes 17 + 1)$取模
$f(n)=sumlimits_{i=0}^{n} sumlimits_{j=0}^{i} 2^j imes sumlimits_{k=0}^{j} (-1)^k imes frac{j!}{k! imes (j-k)!} imes (j-k)^i$
$=sumlimits_{i=0}^{n} sumlimits_{j=0}^{i} 2^j imes j! imes sumlimits_{k=0}^{j} frac{(j-k)^i}{(j-k)!} imes frac{(-1)^k}{k!}$
$=sumlimits_{j=0}^{n} 2^j imes j! imes sumlimits_{k=0}^{j} frac{sumlimits_{i=0}^{n}(j-k)^i}{(j-k)!} imes frac{(-1)^k}{k!}$
可以发现,$sumlimits_{i=0}^{n}(j-k)^i$项就是一个等比数列求和,可以快速幂求出。
那么两个分数分别只与j-k和k有关了,相乘的话,就是卷积形式FFT求出,枚举最外层j即可。
Update10/04:
终于抽出时间码完啦,少打了一个等号调了半天~
1 #include<cstdio> 2 #define mod 998244353 3 #define int long long 4 int rev[400005],bin=1,n,fac[100005],inv[100005],invv[100005],INV,sumpw[100005]; 5 int a[400005],b[400005],sum; 6 int pow(int b,int t,int a=1){for(;t;t>>=1,b=b*b%mod)if(t&1)a=a*b%mod;return a;} 7 void NTT(int *a,int opt){ 8 for(int i=1;i<bin;++i)if(i<rev[i])a[i]^=a[rev[i]]^=a[i]^=a[rev[i]]; 9 for(int mid=1,wn=pow(3,mod-1>>1);mid<bin;mid<<=1,wn=pow(3,(mod-1)/2/mid*opt+mod-1)) 10 for(int i=0;i<bin;i+=mid<<1) 11 for(int j=0,w=1;j<mid;++j,w=w*wn%mod){ 12 int x=a[i+j],y=a[i+j+mid]*w%mod; 13 a[i+j]=(x+y)%mod;a[i+j+mid]=(mod+x-y)%mod; 14 } 15 if(opt==-1)for(int i=0;i<bin;++i)a[i]=a[i]*INV%mod; 16 } 17 main(){ 18 scanf("%lld",&n); 19 while(bin<=n<<1)bin<<=1;//printf("%lld ",bin); 20 for(int i=1;i<bin;++i)rev[i]=rev[i>>1]>>1|(i&1)*bin>>1; 21 INV=pow(bin,mod-2); 22 fac[1]=inv[1]=invv[1]=fac[0]=inv[0]=sumpw[0]=1; 23 for(int i=2;i<=n;++i)fac[i]=fac[i-1]*i%mod,invv[i]=-mod/i*invv[mod%i]%mod+mod,inv[i]=inv[i-1]*invv[i]%mod; 24 sumpw[1]=n+1;for(int i=2;i<=n;++i)sumpw[i]=(pow(i,n+1)-1)*invv[i-1]%mod; 25 for(int i=0;i<=n;++i)a[i]=sumpw[i]*inv[i]%mod,b[i]=pow(mod-1,i)*inv[i]%mod;//,printf("%lld %lld ",a[i],b[i]); 26 NTT(a,1);NTT(b,1); 27 for(int i=0;i<bin;++i)a[i]=a[i]*b[i]%mod; 28 NTT(a,-1);//for(int i=0;i<bin;++i)printf("%lld ",a[i]); 29 for(int j=0;j<=n;++j)sum=(sum+pow(2,j)*fac[j]%mod*a[j])%mod; 30 printf("%lld ",sum); 31 }