ntt+cdq分治
原来zwh出的cf是斯特林
第二类斯特林数的定义是S(i,j)表示将i个物品分到j个无序集合的方案数,那么这道题中S(i,j)*j!*2^j是指将i个物品分到j个有序集合中并且每个集合可以选或不选的方案数,那么我们改变这个公式,得出
F[i]=∑F[j]*2*C(i,j),j=0-n,意思是第一个集合选n-j个的方案数,那么这个集合有两种情况选或不选,乘上2,再乘上选出元素的方案数。然后展开组合数,得出F[i]=∑F[j]*2*i!/(i-j)!/j!,移项得出F[i]/i!=∑F[j]/j!*2/(i-j)!
设新的函数G[i]=F[i]/i!,那么G[i]=∑G[j]*2/(i-j)!,后面是卷积形式,用ntt优化,又因为两边都有G,所以我们用cdq分治求和,复杂度nlog^2n
#include<bits/stdc++.h> using namespace std; typedef long long ll; const int N = (1 << 18) + 5, mod = 998244353; int n; ll ans; int rev[N]; ll a[N], b[N], fac[N], facinv[N], inv[N], f[N]; ll power(ll x, ll t) { ll ret = 1; for(; t; t >>= 1, x = x * x % mod) if(t & 1) ret = ret * x % mod; return ret; } void ntt(ll *a, int n, int f) { for(int i = 0; i < n; ++i) if(i < rev[i]) swap(a[i], a[rev[i]]); for(int m = 2; m <= n; m <<= 1) { int mid = (m >> 1); ll wn = power(3, f == 1 ? (mod - 1) / m : mod - 1 - (mod - 1) / m); for(int i = 0; i < n; i += m) { ll w = 1; for(int j = 0; j < mid; ++j) { ll u = a[i + j], v = a[i + j + mid] * w % mod; a[i + j] = (u + v) % mod; a[i + j + mid] = (u - v + mod) % mod; w = w * wn % mod; } } } if(f == -1) { ll inv = power(n, mod - 2); for(int i = 0; i < n; ++i) a[i] = a[i] * inv % mod; } } void cdq(int l, int r) { if(l == r) return; int mid = (l + r) >> 1; cdq(l, mid); int lim = r - l + 1, n = 1, k = 0; while(n < lim) { n <<= 1; ++k; } for(int i = 0; i < n; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (k - 1)); for(int i = 0; i < n; ++i) a[i] = b[i] = 0; for(int i = l; i <= mid; ++i) a[i - l] = f[i]; for(int i = 0; i < lim; ++i) b[i] = facinv[i]; ntt(a, n, 1); ntt(b, n, 1); for(int i = 0; i < n; ++i) a[i] = a[i] * b[i] % mod; ntt(a, n, -1); for(int i = mid + 1; i <= r; ++i) f[i] = (f[i] + 2 * a[i - l]) % mod; cdq(mid + 1, r); } int main() { scanf("%d", &n); fac[0] = inv[1] = facinv[0] = 1; for(int i = 1; i <= n; ++i) { fac[i] = fac[i - 1] * i % mod; if(i != 1) inv[i] = (mod - mod / i) * inv[mod % i] % mod; facinv[i] = facinv[i - 1] * inv[i] % mod; } f[0] = 1; cdq(0, n); for(int i = 0; i <= n; ++i) ans = (ans + f[i] * fac[i] % mod) % mod; printf("%lld ", ans); return 0; }