LOJ2541 「PKUWC2018」猎人杀
题目大意
(n) 个猎人,每个猎人有一个权值 (w_i)。每个猎人死去后,会开枪打死一个还活着的猎人。假设当前还活着的猎人为 ({i_1, dots, i_m}),那么有 (frac{w_{i_{k}}}{sum_{j = 1}^{m}w_{i_{j}}}) 的概率向 (i_k) 开枪。第一枪由你打响,目标的选择方法和猎人一样。由于开枪导致的连锁反应,所有猎人最终都会死亡。请求出 (1) 号猎人最后一个死的的概率。答案对 (998244353) 取模。
数据范围 (w_i > 0),(1leq sum_{i = 1}^{n}w_ileq 10^5)。
前置知识
- (forall x in[0, 1):sum_{i = 0}^{infty} x^i = frac{1}{1 - x})。
- 给定 ({a_i}, {b_i}, {c_i}) 序列,求关于 (x) 的多项式:(prod_{i = 1}^{n}(c_i + a_ix^{b_i})),其中 (nleq B = sum b_ileq 10^5)。可以用分治 FFT 做。时间复杂度 (mathcal{O}(Blog^2 B))。
本题题解
初步转化:
因为每个猎人死后,概率的分母也会改变,这给我们的计算带来了极大的不便。于是考虑不改变分母:也就是不管死了多少个猎人,我们开枪时仍然在 (n) 个猎人里进行选择,如果选到了已经死去的猎人,就假装无事发生,再选一次,直到某次选中了活着的猎人为止。
结论:这样转化后,每个还活着的猎人被选中的概率不会变化。
以下是简单的证明。设当前还活着的人的集合为 (A = {a_1, dots, a_{|A|}}),已经死去的人的集合为 (D = {d_1, dots, d_{|D|}})。设 (W = sum_{i = 1}^{n}w_i)(即所有猎人的权值之和),(T = sum_{i in D} w_i)(即所有已死的猎人的权值之和)。设原问题中,杀死当前还活着的人 (a_k) 的概率为 (P_1),转化后杀死他的概率为 (P_2)。显然:
对 (P_2),可以枚举在选到活着的猎人之前,选了几次已死的猎人。则:
容斥:
接下来做这个初步转化后的问题。考虑容斥,设集合 (S) 里的这些人死的时间晚于 (1)((Ssubseteq{2, dots, n})),其他人死的时间不限(可以早于 (1) 也可以晚于 (1))。我们要计算这种情况发生的概率,然后乘以 ((-1)^{|S|}),累加进答案。
设 (mathrm{sum}(S) = sum_{iin S} w_i)(即 (S) 集合里点的权值和)。则:
因为 (frac{W - mathrm{sum}(S) - w_1}{W}in[0, 1)),故考虑使用公式:(forall x in[0, 1):sum_{i = 0}^{infty} x^i = frac{1}{1 - x})。则:
设 (f_i) 表示 (mathrm{sum}(S) = i) 的集合 (S) 的 ((-1)^{|S|}) 之和。则:
问题转化为求 (f) 数组。
生成函数:
考虑把一组 (w) 贡献到 (f) 里的过程:下标相加,系数相乘。容易联想到生成函数。
具体来说,构造 (F(x) = sum_{i = 0}^{infty} f_i x^{i})。则:
因为 (sum w_ileq 10^5),上式可以用分治 NTT 求出。时间复杂度 (mathcal{O}(nlog ^2 n))(此处认为 (W, n) 同阶)。
参考代码
实际提交时建议使用输入、输出优化,详见本博客公告。
// problem: LOJ2541
#include <bits/stdc++.h>
using namespace std;
#define mk make_pair
#define fi first
#define se second
#define SZ(x) ((int)(x).size())
typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
template<typename T> inline void ckmax(T &x, T y) {
x = (y > x ? y : x);
}
template<typename T> inline void ckmin(T &x, T y) {
x = (y < x ? y : x);
}
const int MAXN = 1e5;
const int MOD = 998244353;
inline int mod1(int x) {
return x < MOD ? x : x - MOD;
}
inline int mod2(int x) {
return x < 0 ? x + MOD : x;
}
inline void add(int &x, int y) {
x = mod1(x + y);
}
inline void sub(int &x, int y) {
x = mod2(x - y);
}
inline int pow_mod(int x, int i) {
int y = 1;
while (i) {
if (i & 1)
y = (ll)y * x % MOD;
x = (ll)x * x % MOD;
i >>= 1;
}
return y;
}
namespace PolyNTT {
int rev[MAXN * 4 + 5];
int f[MAXN * 4 + 5], g[MAXN * 4 + 5];
void NTT(int *a, int n, int flag) {
for (int i = 0; i < n; ++i)
if (i < rev[i])
swap(a[i], a[rev[i]]);
for (int i = 1; i < n; i <<= 1) {
int T = pow_mod(3, (MOD - 1) / (i << 1));
if (flag == -1)
T = pow_mod(T, MOD - 2);
for (int j = 0; j < n; j += (i << 1)) {
for (int k = 0, t = 1; k < i; ++k, t = (ll)t * T % MOD) {
int Nx = a[j + k], Ny = (ll)a[i + j + k] * t % MOD;
a[j + k] = mod1(Nx + Ny);
a[i + j + k] = mod2(Nx - Ny);
}
}
}
if (flag == -1) {
int invn = pow_mod(n, MOD - 2);
for (int i = 0; i < n; ++i)
a[i] = (ll)a[i] * invn % MOD;
}
}
void mul(int n, int m) {
int lim = 1, ct = 0;
while (lim <= n + m)
lim <<= 1, ct++;
for (int i = n; i <= lim; ++i)
f[i] = 0;
for (int i = m; i <= lim; ++i)
g[i] = 0; //clear
for (int i = 0; i < lim; ++i)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (ct - 1));
NTT(f, lim, 1);
NTT(g, lim, 1);
for (int i = 0; i < lim; ++i)
f[i] = (ll)f[i] * g[i] % MOD;
NTT(f, lim, -1);
}
}//namespace PolyNTT
typedef vector<int> Poly;
Poly operator*(const Poly &a, const Poly &b) {
if (!SZ(a) && !SZ(b))
return Poly();
Poly res;
res.resize(SZ(a) + SZ(b) - 1);
if (SZ(a) <= 50 && SZ(b) <= 50) {
for (int i = 0; i < SZ(a); ++i)
for (int j = 0; j < SZ(b); ++j)
add(res[i + j], (ll)a[i]*b[j] % MOD);
return res;
}
for (int i = 0; i < SZ(a); ++i)
PolyNTT::f[i] = a[i];
for (int i = 0; i < SZ(b); ++i)
PolyNTT::g[i] = b[i];
PolyNTT::mul(SZ(a), SZ(b));
for (int i = 0; i < SZ(res); ++i)
res[i] = PolyNTT::f[i];
return res;
}
Poly &operator*=(Poly &lhs, const Poly &rhs) {
lhs = lhs * rhs;
return lhs;
}
Poly operator+(const Poly &a, const Poly &b) {
Poly res;
res.resize(max(SZ(a), SZ(b)));
for (int i = 0; i < SZ(res); ++i) {
res[i] = mod1((i >= SZ(a) ? 0 : a[i]) + (i >= SZ(b) ? 0 : b[i]));
}
return res;
}
Poly &operator+=(Poly &lhs, const Poly &rhs) {
lhs = lhs + rhs;
return lhs;
}
int n, w[MAXN + 5], W;
Poly solve(int l, int r) {
if (l == r) {
Poly res;
res.resize(w[l] + 1);
res[0] = 1;
res[w[l]] = MOD - 1;
return res;
}
int mid = (l + r) >> 1;
return solve(l, mid) * solve(mid + 1, r);
}
int main() {
cin >> n;
if (n == 1) {
cout << 1 << endl;
return 0;
}
for (int i = 1; i <= n; ++i) {
cin >> w[i];
W += w[i];
}
Poly f = solve(2, n);
assert(SZ(f) == W - w[1] + 1);
int ans = 0;
for (int i = 0; i <= W - w[1]; ++i) {
add(ans, (ll)f[i] * pow_mod(i + w[1], MOD - 2) % MOD);
}
ans = (ll)ans * w[1] % MOD;
cout << ans << endl;
return 0;
}