题面
题解
前置知识:
首先考虑朴素做法。设(f[u][i])表示点u取到第i大的权值的概率,然后进行树上dp。但是这个dp最多优化到(O(n^2)),并且时间和空间都是(O(n^2)),无法通过。
考虑使用线段树合并去实现这个dp的过程。对于树上的每一个节点u,开一棵线段树去记录(f[u])。具体实现是,在原树上进行dfs,到点u时:
- u是叶子节点:开一条新的链作为u的线段树。
- u有一个子节点:u的线段树就是u儿子的线段树。
- u有两个子节点:u的线段树就是u两个儿子线段树的并。
合并线段树时,只需要解决一棵线段树与一棵空树怎么(O(1))的合并。
假设原树上dfs到的节点是u,它的左右儿子节点是l、r。如果我们需要将x,y为根的两棵线段树合并(x,y是l,r对应的线段树上,两个相同位置的节点)而y为空,那么:
[forall x.l leq i leq x.r,f[u][i] = f[l][i](sumlimits_{j=1}^{i}f[r][j] imes p[u] + sumlimits_{j=i}^{wn}f[r][j]*(1-p[u]))
]
- 其中(x.l,x.r)表示x节点对应的线段的左右端点,(wn)表示不同权值的个数。
并且由于
[forall x.l leq i leq x.r,f[r][i] =0
]
所以有
(forall x.l leq i leq x.r,f[u][i] = f[l][i] imes (sumlimits_{i=1}^{x.l-1}f[r][j] imes p[u] + sumlimits_{i=x.r+1}^{wn}f[r][j] imes (1-p[u])))
因此,只需要在merge的过程中,时刻维护前缀和(pre_r=sum_{i=1}^{x.l-1}f[r][j])和后缀和(suf_r=sum_{i=x.r+1}^{wn}f[r][j])。
如果x不为空而y为空,只需要将x打上“整体乘((pre_r imes p[u] + suf_r imes (1-p[u])))”的标记即可。当然也可能出现x为空而(y)不为空的情况,相应地维护(pre_l)和(suf_l)就可以啦。
总时间复杂度(O(n log n))。
代码
#include<bits/stdc++.h>
using namespace std;
#define rg register
#define In inline
#define ll long long
const int N = 3e5;
const int TN = 2 * 19 * N;
const int mod = 998244353;
const int inv_10000 = 796898467;
namespace IO{
In int read(){
int s = 0,ww = 1;
char ch = getchar();
while(ch < '0' || ch > '9'){if(ch == '-')ww = -1;ch = getchar();}
while('0' <= ch && ch <= '9'){s = 10 * s + ch - '0';ch = getchar();}
return s * ww;
}
In void write(int x){
if(x < 0)x = -x,putchar('-');
if(x > 9)write(x / 10);
putchar('0' + x % 10);
}
}
using namespace IO;
namespace ModCalc{
In void Inc(int &x,int y){
x += y;if(x >= mod)x -= mod;
}
In void Dec(int &x,int y){
x -= y;if(x < 0)x += mod;
}
In void Tms(int &x,int y){
x = 1ll * x * y % mod;
}
In int Add(int x,int y){
Inc(x,y);return x;
}
In int Sub(int x,int y){
Dec(x,y);return x;
}
In int Mul(int x,int y){
Tms(x,y);return x;
}
}
using namespace ModCalc;
int rt[N+5],p[N+5],q[N+5];
int D[N+5];
int n;
struct SegTree{
int sum[TN+5],flag[TN+5],lc[TN+5],rc[TN+5];
int cnt;
In int newnode(){
cnt++;
flag[cnt] = 1;
return cnt;
}
In void pushdown(int u){
if(flag[u] == 1)return;
int L = lc[u],R = rc[u];
if(L){
Tms(sum[L],flag[u]);
Tms(flag[L],flag[u]);
}
if(R){
Tms(sum[R],flag[u]);
Tms(flag[R],flag[u]);
}
flag[u] = 1;
}
In void pushup(int u){
sum[u] = Add(sum[lc[u]],sum[rc[u]]);
}
int insert(int l,int r,int x,int d){
int u = newnode();
if(l == r){
sum[u] = d;
return u;
}
int m = (l + r) >> 1;
if(x <= m)lc[u] = insert(l,m,x,d);
else rc[u] = insert(m + 1,r,x,d);
pushup(u);
return u;
}
int merge(int u,int v,int cur,int pu,int su,int pv,int sv){ //u,v对应题解中的x,y;cur对应题解中的u;pu,su,pv,sv对应题解中的pre_l,suf_l,pre_r,suf_r
if(!u && !v)return 0;
if(!u){
int x = Add(Mul(pu,p[cur]),Mul(su,q[cur]));
Tms(sum[v],x);
Tms(flag[v],x);
return v;
}
if(!v){
int x = Add(Mul(pv,p[cur]),Mul(sv,q[cur]));
Tms(sum[u],x);
Tms(flag[u],x);
return u;
}
pushdown(u),pushdown(v);
int dpu = sum[lc[u]],dsu = sum[rc[u]],dpv = sum[lc[v]],dsv = sum[rc[v]];
lc[u] = merge(lc[u],lc[v],cur,pu,Add(su,dsu),pv,Add(sv,dsv));
rc[u] = merge(rc[u],rc[v],cur,Add(pu,dpu),su,Add(pv,dpv),sv);
pushup(u);
return u;
}
void dfs(int u,int l,int r){
if(l == r)D[l] = sum[u];
pushdown(u);
int m = (l + r) >> 1;
if(lc[u])dfs(lc[u],l,m);
if(rc[u])dfs(rc[u],m + 1,r);
}
}T;
vector<int>link[N+5];
int w[N+5],aw[N+5],wn;
void prepro(){ //离散化
for(rg int i = 1;i <= n;i++)if(!link[i].size())aw[++wn] = w[i];
sort(aw + 1,aw + wn + 1);
for(rg int i = 1;i <= n;i++)if(!link[i].size())
w[i] = lower_bound(aw + 1,aw + wn + 1,w[i]) - aw;
}
void dfs(int u){
if(!link[u].size())rt[u] = T.insert(1,wn,w[u],1);
else if(link[u].size() == 1){
dfs(link[u][0]);
rt[u] = rt[link[u][0]];
}
else{
int l = link[u][0],r = link[u][1];
dfs(l),dfs(r);
rt[u] = T.merge(rt[l],rt[r],u,0,0,0,0);
}
}
int main(){
n = read();
for(rg int i = 1;i <= n;i++)link[read()].push_back(i);
for(rg int i = 1;i <= n;i++){
int x = read();
if(!link[i].size())w[i] = x;
else p[i] = Mul(x,inv_10000),q[i] = Sub(1,p[i]);
}
prepro();
dfs(1);
T.dfs(rt[1],1,wn);
int ans = 0;
for(rg int i = 1;i <= wn;i++)Inc(ans,Mul(Mul(i,aw[i]),Mul(D[i],D[i])));
write(ans),putchar('
');
return 0;
}