看到恰好 (k) 对这样的限制,我们考虑容斥:设 (g_k) 表示至少有 (k) 对的方案,(ans_k) 表示恰好有 (k) 对的方案:
[g_k = sum_{i = k} ^ {n - 1} {i choose k} ans_i Rightarrow ans_k = sum_{i = k} ^ {n - 1} (-1) ^ {i - k} { i choose k } g_i
]
然后我们考虑如何计算 (g_k)。
发现同种类型的牌都是一样的,不好算方案,我们不妨先认为同种类型的牌互不相同,那么在最后的答案中我们除以 (prod_{i = 1} ^ m a_i !) 就好了(可重集排列)
我们考虑 ( ext{DP}) :设 (dp_{i, j}) 表示用前 (i) 种类型的牌,组成了至少 (j) 对的方案,为了方便转移,我们用 (f_{i, j}) 表示用第 (i) 种类型的牌,组成了至少 (j) 对的方案。
考虑 (f_{i, j}) 怎么计算,由于我们之前已经钦定了每张牌都不相同,所以我们可以先从 (a_i) 中牌中选 (a_i - j) 张作为不直接参与答案计算的牌,然后考虑把剩下的 (j) 张插入到这些牌的左边,显然这样至少会有 (j) 对。那么对于这 (j) 张牌,第一张有 (a_i - j) 张牌可以插,第二张有 (a_i - j + 1) 张可以插,以此类推不难得到 (f_{i, j}) 的表达式:
[f_{i, j} = {a_i choose a_i - j} (a_i - 1)^{underline j}
]
转移方程很显然:(dp_{i, j} = sum_{k = 0} ^ {j} dp_{i - 1, j - k} imes f_{i, k})
不难做到用分治 ( ext{NTT})(雾)优化求 (dp_{i, j}) 的过程。
但是这里要注意一点就是我们求出来的 (dp_{m, k}) 和 (g_k) 是不等价的,因为我们对于和魔术对无关的 (n - k) 张牌是可以随便排的,也就是说 (g_k = dp_{m, k} imes (n - k)!)
最后我们就可以直接计算 (ans_k) 了。
参考代码:
#include <algorithm>
#include <cstdio>
#include <vector>
using namespace std;
const int _ = 4e5 + 5, p = 998244353, G = 3, iG = 332748118;
template < class T > void read(T& s) {
s = 0; int f = 0; char c = getchar();
while ('0' > c || c > '9') f |= c == '-', c = getchar();
while ('0' <= c && c <= '9') s = s * 10 + c - 48, c = getchar();
s = f ? -s : s;
}
int n, m, k, a[_], fac[_], ifc[_], r[_];
vector < int > f[_], g;
int power(int x, int k) {
int res = 1;
for (; k; k >>= 1, x = 1ll * x * x % p)
if (k & 1) res = 1ll * res * x % p;
return res % p;
}
int C(int N, int M) { return 1ll * fac[N] * ifc[M] % p * ifc[N - M] % p; }
void NTT(vector < int > & A, int N, int type) {
for (int i = 0; i < N; ++i) if (i < r[i]) swap(A[i], A[r[i]]);
for (int i = 1; i < N; i <<= 1) {
int Wn = power(type ? G : iG, (p - 1) / (i << 1));
for (int j = 0; j < N; j += i << 1)
for (int k = 0, w = 1; k < i; ++k, w = 1ll * w * Wn % p) {
int x = A[j + k], y = 1ll * w * A[j + i + k] % p;
A[j + k] = (x + y) % p, A[j + i + k] = (x - y + p) % p;
}
}
if (!type) {
int inv = power(N, p - 2);
for (int i = 0; i < N; ++i) A[i] = 1ll * A[i] * inv % p;
}
}
vector < int > solve(int L, int R) {
if (L == R) return f[L];
int mid = (L + R) >> 1;
vector < int > A = solve(L, mid);
vector < int > B = solve(mid + 1, R);
int N = 1, l = 0;
while (N <= A.size() + B.size()) N <<= 1, ++l;
for (int i = 0; i < N; ++i)
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
while (A.size() != N) A.push_back(0);
while (B.size() != N) B.push_back(0);
NTT(A, N, 1), NTT(B, N, 1);
for (int i = 0; i < N; ++i) A[i] = 1ll * A[i] * B[i] % p;
NTT(A, N, 0);
while (A.size() && !A.back()) A.pop_back();
return A;
}
int main() {
#ifndef ONLINE_JUDGE
freopen("cpp.in", "r", stdin), freopen("cpp.out", "w", stdout);
#endif
read(m), read(n), read(k);
for (int i = 1; i <= m; ++i) read(a[i]);
fac[0] = 1;
for (int i = 1; i <= n; ++i) fac[i] = 1ll * fac[i - 1] * i % p;
ifc[n] = power(fac[n], p - 2);
for (int i = n; i; --i) ifc[i - 1] = 1ll * ifc[i] * i % p;
for (int i = 1; i <= m; ++i)
for (int j = 0; j < a[i]; ++j)
f[i].push_back(1ll * C(a[i], j) * fac[a[i] - 1] % p * ifc[a[i] - j - 1] % p);
g = solve(1, m);
for (int i = 0; i < g.size(); ++i) g[i] = 1ll * g[i] * fac[n - i] % p;
int ans = 0;
for (int i = k; i < g.size(); ++i) {
int tmp = 1ll * C(i, k) * g[i] % p;
if (i - k & 1) ans = (ans - tmp + p) % p;
else ans = (ans + tmp) % p;
}
for (int i = 1; i <= m; ++i) ans = 1ll * ans * ifc[a[i]] % p;
printf("%d
", ans);
return 0;
}