Solution
- \(O(n^2)\) 做法不会的先去看这个
- 这里只讲如何快速求第一类斯特林数 \(s(n,m)\)
- 首先有递推式:\(s(i,j)=s(i-1,j-1)+(i-1)*s(i-1,j)\)
- 为方便卷积写成这样(第二维和为 \(j\)):\(s(i,j)=s(i-1,j-1)*b(i,1)+b(i,0)*s(i-1,j)\)
- 其中 \(b(i,1)=1,b(i,0)=i-1\)
- 那么把 \(s(i)\) 看成一个多项式,\(s(i,j)\) 为这个多项式 \(x^j\) 项的系数,初值:\(s(0,0)=1\)
- \(b(i)\) 同理
- 那么 \(s(i)=s(i-1)*b(i)\)
- 于是把 \(s(0)\) ~ \(s(n)\) 都乘起来,得到的多项式就是 \(s(n)\)
- 这个多项式的 \(x^i\) 项的系数就是 \(s(n,i)\)
- 分治 \(ntt\) 即可,时间复杂度 \(O(n \log^2n)\)
code
#include <bits/stdc++.h>
using namespace std;
#define ll long long
const int e = 1e6 + 5, mod = 998244353;
int n, a1, b1, fac[e], inv[e], rev[e], lim;
vector<int>g[e];
inline int ksm(int x, int y)
{
int res = 1;
while (y)
{
if (y & 1) res = (ll)res * x % mod;
y >>= 1;
x = (ll)x * x % mod;
}
return res;
}
inline void upt(int &x, int y)
{
x = y;
if (x >= mod) x -= mod;
}
inline void fft(int n, int *a, int opt)
{
int i, j, k, r = (opt == 1 ? 3 : (mod + 1) / 3);
for (i = 0; i < n; i++)
if (i < rev[i]) swap(a[i], a[rev[i]]);
for (k = 1; k < n; k <<= 1)
{
int w0 = ksm(r, (mod - 1) / (k << 1));
for (i = 0; i < n; i += (k << 1))
{
int w = 1;
for (j = 0; j < k; j++)
{
int b = a[i + j], c = (ll)w * a[i + j + k] % mod;
upt(a[i + j], b + c);
upt(a[i + j + k], b + mod - c);
w = (ll)w * w0 % mod;
}
}
}
}
inline void solve(int l, int r)
{
if (l >= r) return;
int i, mid = l + r >> 1;
solve(l, mid);
solve(mid + 1, r);
static int a[266666], b[266666], c[266666];
int k = 0, la = g[l].size(), lb = g[mid + 1].size();
lim = 1;
while (lim < la + lb - 1)
{
lim <<= 1;
k++;
}
for (i = 0; i < lim; i++)
{
a[i] = b[i] = 0;
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << k - 1);
}
for (i = 0; i < la; i++) a[i] = g[l][i];
for (i = 0; i < lb; i++) b[i] = g[mid + 1][i];
fft(lim, a, 1);
fft(lim, b, 1);
for (i = 0; i < lim; i++) a[i] = (ll)a[i] * b[i] % mod;
fft(lim, a, -1);
int tot = ksm(lim, mod - 2);
for (i = 0; i < lim; i++) a[i] = (ll)a[i] * tot % mod;
g[l].clear();
for (i = 0; i < la + lb - 1; i++) g[l].push_back(a[i]);
}
inline int c(int x, int y)
{
if (x < y) return 0;
return (ll)fac[x] * inv[y] % mod * inv[x - y] % mod;
}
int main()
{
int i;
cin >> n >> a1 >> b1;
fac[0] = 1;
for (i = 1; i <= n; i++) fac[i] = (ll)fac[i - 1] * i % mod;
inv[n] = ksm(fac[n], mod - 2);
for (i = n - 1; i >= 0; i--) inv[i] = (ll)inv[i + 1] * (i + 1) % mod;
int res = c(a1 + b1 - 2, a1 - 1);
g[0].push_back(1);
for (i = 1; i <= n; i++)
{
g[i].push_back(i - 1);
g[i].push_back(1);
}
solve(0, n - 1);
if (a1 + b1 - 2 < g[0].size()) res = (ll)res * g[0][a1 + b1 - 2] % mod;
else res = 0;
cout << res << endl;
return 0;
}