题面
题解
考虑一个数字被取到最小值的概率怎么算。
由于一个节点最多只有 (2) 个儿子,所以 (x) 出现的概率 (a_x) 分为两个部分,一个作为最大值,另一个即作为最小值。
以计算这个点作为最小值出现的概率为例,这个概率就是这个数在这棵子树内出现的概率 (a_x') 乘以另外一棵子树中取到比它大的数字的概率再乘上这个点取最小值的概率 (1 - p_x)。最大值同理。
于是对每个点维护一棵线段树维护每个点出现的概率,线段树合并即可。
代码
#include <cstdio>
#include <algorithm>
#include <vector>
inline int read()
{
int data = 0, w = 1; char ch = getchar();
while (ch != '-' && (ch < '0' || ch > '9')) ch = getchar();
if (ch == '-') w = -1, ch = getchar();
while (ch >= '0' && ch <= '9') data = data * 10 + (ch ^ 48), ch = getchar();
return data * w;
}
const int N(3e5 + 10), Mod(998244353), Inv(796898467), LIM(1e9);
struct edge { int next, to; } e[N << 1];
int n, m, w[N], head[N], e_num, rt[N], cnt;
int cur, son[2][N * 20], tag[N * 20], p[N * 20], ans;
inline void add_edge(int from, int to)
{
e[++e_num] = (edge) {head[from], to};
head[from] = e_num;
}
void pushup(int x) { p[x] = (p[son[0][x]] + p[son[1][x]]) % Mod; }
void pushtag(int x, int _)
{ p[x] = 1ll * p[x] * _ % Mod, tag[x] = 1ll * tag[x] * _ % Mod; }
void pushdown(int x)
{
if (tag[x] == 1) return;
pushtag(son[0][x], tag[x]);
pushtag(son[1][x], tag[x]);
tag[x] = 1;
}
void Solve(int x, int l = 1, int r = LIM)
{
if (!x) return;
if (l == r) return (void) (ans = (ans + 1ll * (++cnt) * l % Mod * p[x] % Mod * p[x]) % Mod);
int mid = (l + r) >> 1; pushdown(x);
Solve(son[0][x], l, mid), Solve(son[1][x], mid + 1, r);
}
void Insert(int &x, int t, int l = 1, int r = LIM)
{
if (!x) x = ++cur; p[x] = tag[x] = 1;
if (l == r) return; int mid = (l + r) >> 1;
if (t <= mid) Insert(son[0][x], t, l, mid);
else Insert(son[1][x], t, mid + 1, r);
pushup(x);
}
int Merge(int x, int y, int px, int py, int t)
{
if (!x && !y) return 0;
if (!x) return pushtag(y, px), y;
if (!y) return pushtag(x, py), x;
pushdown(x), pushdown(y);
int pxl, pxr, pyl, pyr;
pxl = (px + 1ll * (Mod + 1 - t) * p[son[1][x]]) % Mod;
pyl = (py + 1ll * (Mod + 1 - t) * p[son[1][y]]) % Mod;
pxr = (px + 1ll * t * p[son[0][x]]) % Mod;
pyr = (py + 1ll * t * p[son[0][y]]) % Mod;
son[0][x] = Merge(son[0][x], son[0][y], pxl, pyl, t);
son[1][x] = Merge(son[1][x], son[1][y], pxr, pyr, t);
return pushup(x), x;
}
void dfs(int x)
{
int a[2], tot = 0;
for (int i = head[x]; i; i = e[i].next)
dfs(e[i].to), a[tot++] = rt[e[i].to];
if (tot == 0) Insert(rt[x], w[x]);
else if (tot == 1) rt[x] = a[0];
else rt[x] = Merge(a[0], a[1], 0, 0, 1ll * w[x] * Inv % Mod);
}
int main()
{
n = read();
for (int i = 1; i <= n; i++) add_edge(read(), i);
for (int i = 1; i <= n; i++) w[i] = read();
dfs(1), Solve(rt[1]), printf("%d
", ans);
return 0;
}