【CF960G】Bandit Blues
题面
题解
思路和这道题一模一样,这里仅仅阐述优化的方法。
看看答案是什么:
[Ans=C(a+b-2,a-1)centerdot s(n-1,a+b-2)
]
组合数我们已经可以(O(N))求了,主要是第一类斯特林数存在问题。
考虑它的转移:
[s(n,m)=s(n-1,m-1)+(n-1)*s(n-1,m)
]
根据这个转移,我们写出它(n)固定时的生成函数
[G(x)=prod_{i=0}^{n-1}(x+i)
]
然后每一个(s(n,m))就是升序第(m)项的次数。
为什么生成函数是这个?
引用(yyb)的:
把(n)为定值时的所有的第一类斯特林数按照(n)分类分成行,发现每次的(s(n,m))转移必定要从(n−1)行转移过来,而每次转移都是(m−1)变到(m),系数为(1),因此有一项(x),同理有一项(n−1),因此就可以得到上面的那个生成函数。
然后对于这个东西我们用分治+(NTT)就可以(O(nlog^2))地做了。
代码
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <vector>
using namespace std;
const int Mod = 998244353;
int fpow(int x, int y) {
int res = 1;
while (y) {
if (y & 1) res = 1ll * res * x % Mod;
y >>= 1;
x = 1ll * x * x % Mod;
}
return res;
}
const int G = 3, iG = fpow(G, Mod - 2);
const int MAX_N = 3e5 + 5;
int rev[MAX_N], Limit;
void NTT(vector<int> &p, int op) {
for (int i = 0; i < Limit; i++) if (i < rev[i]) swap(p[i], p[rev[i]]);
for (int i = 1; i < Limit; i <<= 1) {
int rot = fpow(op == 1 ? G : iG, (Mod - 1) / (i << 1));
for (int j = 0; j < Limit; j += (i << 1)) {
int w = 1;
for (int k = 0; k < i; k++, w = 1ll * w * rot % Mod) {
int x = p[j + k], y = 1ll * w * p[i + j + k] % Mod;
p[j + k] = (x + y) % Mod, p[i + j + k] = (x - y + Mod) % Mod;
}
}
}
if (op == -1) {
int inv = fpow(Limit, Mod - 2);
for (int i = 0; i < Limit; i++) p[i] = 1ll * inv * p[i] % Mod;
}
}
vector<int> mul(vector<int> &A, vector<int> &B) {
static vector<int> C;
C.clear();
int p = 0, sz = A.size() + B.size() - 2;
for (Limit = 1; Limit <= sz + 1; Limit <<= 1, ++p);
for (int i = 0; i < Limit; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (p - 1));
A.resize(Limit), B.resize(Limit);
NTT(A, 1), NTT(B, 1);
for (int i = 0; i < Limit; i++) C.push_back(1ll * A[i] * B[i] % Mod);
NTT(C, -1);
return C;
}
vector<int> Div(int l, int r) {
vector<int> L, R;
if (l == r) return {l, 1};
int mid = (l + r) >> 1;
L = Div(l, mid), R = Div(mid + 1, r);
return mul(L, R);
}
int fac(int x) { int res = 1; for (int i = 1; i <= x; i++) res = 1ll * res * i % Mod; return res; }
int C(int n, int m) {
if (m > n) return 0;
else return 1ll * fac(n) * fpow(fac(m), Mod - 2) % Mod * fpow(fac(n - m), Mod - 2) % Mod;
}
vector<int> Ans;
int main () {
#ifndef ONLINE_JUDGE
freopen("cpp.in", "r", stdin);
#endif
int N, A, B;
cin >> N >> A >> B;
if (!A || !B || A + B - 2 > N - 1) return puts("0") & 0;
if (N == 1) return puts("1") & 0;
Ans = Div(0, N - 2);
printf("%d
", (int)(1ll * Ans[A + B - 2] * C(A + B - 2, A - 1) % Mod));
return 0;
}