题目链接:https://ac.nowcoder.com/acm/contest/11255/G
题目大意:计算(frac{D!}{prod_{i,a_igeqslant0,sum a_i = D}^{n}(a_i+k)!})
题目思路:
首先考虑(frac{D!}{prod_{isum a_i = D}^{n}a_i!})
参考大佬(Alkaid~)的证明
这个可以看做(D)个不同的球分为(n)组的情况,以拿球放球为例:把每组看成一个栈,假设(D=10,n=4,a=egin{Bmatrix}2,4,3,1end{Bmatrix}),我们对每个球编号,取(a_3=egin{Bmatrix}1,4,7end{Bmatrix}),有一种拿球顺序为(egin{Bmatrix}1,7,4,8,5,3,2,6,10,9end{Bmatrix}),那么(a_3)的放球顺序为(egin{Bmatrix}4,7,1end{Bmatrix}),另一种拿球顺序为(egin{Bmatrix}7,8,10,3,5,1,9,2,4,6end{Bmatrix}),那么(a_3)的放球顺序为(egin{Bmatrix}4,1,7end{Bmatrix}),总共有(D!)种拿球顺序,那么对于(a_3)就有(a_3!)种放球顺序,但实际上所有(a_3)都是情况,因此需要(frac{D!}{a_3!}),那么对于所有的情况就有(frac{D!}{prod_{isum a_i = D}^{n}a_i!}),而每个球都有(n)种去向,因此
由此可得
如果不考虑(a_i<k)的情况,答案
现在需要剔除所有(a_i<k)的情况,可以用容斥来做
设(f_{i,j})表示(j)个球分为(i)个不合法组的方案数,那么通过枚举第(j)组的个数和球的分配情况,可以得到递推式
这里将(j-t)个球分配给前(i-1)的方案数前面已经计算过,即(f_{i-1,j-t}),因此后面直接乘(C_{j}^{t})即可
感谢大佬(Spy_Savior)的答疑
最后通过容斥得到
其中(C_{n}^{i})是取哪些组为不合法,(C_{D+nk}^{j})是取哪些球放到这些不合法组里,这个数不能预处理,只能每次遍历时递推计算,((n-i)^{D+nk-j})即剩下(D+nk-j)个球分为(n-i)组的合法情况
AC代码:
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <vector>
#include <string>
#include <stack>
#include <deque>
#include <queue>
#include <cmath>
#include <map>
#include <set>
using namespace std;
//typedef __int128 int128;
typedef pair<int, int> PII;
typedef long long ll;
typedef unsigned long long ull;
const int INF = 0x3f3f3f3f;
const int N = 4e3 + 10, M = 1e5;
const int base = 1e9;
const int P = 131;
const int mod = 998244353;
const double eps = 1e-6;
ll n, k, D;
ll C[N][N], f[N][N];
ll ksm(ll a, ll b)
{
ll res = 1 % mod;
while (b)
{
if (b & 1)
res = res * a % mod;
a = a * a % mod;
b >>= 1;
}
return res;
}
void init()
{
for (int i = 0; i < N; ++i)
for (int j = 0; j <= i; ++j)
{
if (j == 0)
C[i][j] = 1;
else
C[i][j] = (C[i - 1][j] + C[i - 1][j - 1]) % mod;
}
f[0][0] = 1;
for (int i = 0; i < n; ++i)
for (int j = 0; j <= i * (k - 1); ++j)
for (int t = 0; t < k; ++t)
f[i + 1][j + t] = (f[i + 1][j + t] + f[i][j] * C[j + t][t] % mod) % mod;
}
int main()
{
scanf("%lld%lld%lld", &n, &k, &D);
init();
ll ans = 0;
for (int i = 0; i <= n; ++i)
{
ll CDJ = 1;
for (int j = 0; j <= i * (k - 1); ++j)
{
ll res = i & 1 ? -f[i][j] : f[i][j];
res = res * C[n][i] % mod * CDJ % mod * ksm(n - i, D + n * k - j) % mod;
ans = (ans + res) % mod;
CDJ = CDJ * (D + n * k - j) % mod * ksm(j + 1, mod - 2) % mod;
}
}
for (ll i = D + 1; i <= D + n * k; ++i)
ans = ans * ksm(i, mod - 2) % mod;
ans = (ans % mod + mod) % mod;
printf("%lld
", ans);
return 0;
}