感觉我去pkuwc好像只有爆零的份……
设 $f_{u,i}$ 表示节点 $u$ 的权值取到 $i$ 的概率（$ls$、$rs$ 分别为 $u$ 的左右儿子，$p_u$ 为 $u$ 取较大值的概率），那么有如下转移：
$$f_{u,i}=f_{ls,i}\left(p_u\sum_{j<i}f_{rs,j}+(1-p_u)\sum_{j>i}f_{rs,j}\right)+f_{rs,i}\left(p_u\sum_{j<i}f_{ls,j}+(1-p_u)\sum_{j>i}f_{ls,j}\right)$$
然后用线段树合并即可，最后在根节点的线段树上 $dfs$ 统计答案。
//minamoto
#include<bits/stdc++.h>
// `R` used to expand to `register`. That storage class was removed in C++17
// (Clang rejects it, GCC warns), so the macro is kept for source compatibility
// but now expands to nothing.
#define R
// fp(i,a,b): ascending loop, i = a, a+1, ..., b (inclusive). The bound is
// evaluated once; arguments are parenthesized so expressions can be passed.
#define fp(i,a,b) for(R int i=(a),I=(b)+1;i<I;++i)
// fd(i,a,b): descending loop, i = a, a-1, ..., b (inclusive).
#define fd(i,a,b) for(R int i=(a),I=(b)-1;i>I;--i)
// go(u): adjacency-list walk over `head[]` and edge fields `.v` / `.nx`.
// Nothing in the visible code uses it; it is kept unchanged in case other
// code relies on it.
#define go(u) for(int i=head[u],v=e[i].v;i;i=e[i].nx,v=e[i].v)
using namespace std;
// Input buffer for stdin; [p1, p2) is the part of buf not yet consumed.
char buf[1<<21],*p1=buf,*p2=buf;
// Returns the next byte of stdin, refilling buf with fread when it runs dry.
// When fread delivers nothing, EOF is returned (converted to char).
inline char getc(){
    if(p1==p2){
        p1=buf;
        p2=buf+fread(buf,1,1<<21,stdin);
        if(p1==p2)return EOF;
    }
    return *p1++;
}
// Reads the next decimal integer from stdin through getc().
// Every non-digit byte in front of the number is skipped; if any of them is
// '-', the result is negated. The byte that terminates the digit run is
// consumed as well.
// Returns 0 when the input ends before a digit is found. Previously the skip
// loop never terminated in that case, because getc() keeps returning EOF.
// A literal 0xFF byte cannot be told apart from EOF here, since getc()
// returns char.
int read(){
    const char eof=static_cast<char>(EOF);
    int res=0,f=1;
    char ch=getc();
    while(ch>'9'||ch<'0'){
        if(ch==eof)return 0;
        if(ch=='-')f=-1;
        ch=getc();
    }
    do{
        res=res*10+(ch-'0');
        ch=getc();
    }while(ch>='0'&&ch<='9');
    return res*f;
}
// N   : array capacity
// P   : prime modulus used for all arithmetic
// inv : modular inverse of 10000 modulo P (10000 * inv % P == 1)
const int N=3e5+5,P=998244353,inv=796898467;
// (x + y) mod P, for x, y already in [0, P).
inline int add(int x,int y){
    const int s=x+y;
    return s>=P?s-P:s;
}
// (x - y) mod P, for x, y already in [0, P).
inline int dec(int x,int y){
    const int d=x-y;
    return d<0?d+P:d;
}
// (x * y) mod P; the product is formed in 64 bits to avoid overflow.
inline int mul(int x,int y){
    return static_cast<int>(1ll*x*y%P);
}
// Fast exponentiation: returns x^y mod P. Requires y >= 0.
// The loop inspects the lowest bit of y BEFORE shifting. The previous header
// shifted first, which dropped bit 0 and left x one squaring behind
// (e.g. ksm(2,1) returned 1).
int ksm(int x,int y){
    int res=1;
    for(;y;y>>=1,x=mul(x,x)){
        if(y&1)res=mul(res,x);
    }
    return res;
}
// Pool of segment-tree nodes. Nodes are handed out from index 1 upward (see
// ins), so index 0 serves as the "no node" marker.
//   ls / rs : indices in e[] of the left / right child (0 when absent)
//   sum     : aggregate stored for the node's range; after a merge it is the
//             sum of the two children's `sum`
//   tag     : multiplier still to be pushed to the children (1 = none pending)
struct node{int ls,rs,sum,tag;}e[N*20];
// fa[i]    : parent of tree node i, as read from the input
// ch[i][k] : k-th child of tree node i (0 when absent)
// rt[i]    : root index in e[] of the segment tree created for leaf i
// w[i]     : number read for node i; main multiplies it by `inv` for nodes
//            that have a child and leaves it untouched for leaves
// b[1..m]  : leaf values, sorted and de-duplicated in main
// tot      : number of segment-tree nodes allocated so far
// n        : number of tree nodes
// pi       : the w[] of the tree node currently being merged (set in solve,
//            read by merge)
// m        : number of entries in b[]
int fa[N],ch[N][2],rt[N],w[N],b[N],tot,n,pi,m;
inline void ppd(R int p,R int x){e[p].sum=mul(e[p].sum,x),e[p].tag=mul(e[p].tag,x);}
void pd(R int p){
if(e[p].tag!=1){
ppd(e[p].ls,e[p].tag),ppd(e[p].rs,e[p].tag);
e[p].tag=1;
}
}
// Creates the root-to-leaf path for position x in the segment tree rooted at
// p, which covers [l, r]. Every node missing on that path is allocated with
// sum = tag = 1; existing nodes are left as they are.
void ins(int &p,int l,int r,int x){
    int *slot=&p;               // link that must point at the current node
    for(;;){
        if(!*slot){
            *slot=++tot;
            e[*slot].sum=1;
            e[*slot].tag=1;
        }
        if(l==r)return;
        const int mid=(l+r)>>1;
        if(x<=mid){
            slot=&e[*slot].ls;
            r=mid;
        }else{
            slot=&e[*slot].rs;
            l=mid+1;
        }
    }
}
int merge(int x,int y,int sumx,int sumy){
if(!x)return ppd(y,sumx),y;
if(!y)return ppd(x,sumy),x;
pd(x),pd(y);
int x1=e[e[x].ls].sum,x2=e[e[x].rs].sum,y1=e[e[y].ls].sum,y2=e[e[y].rs].sum;
e[x].ls=merge(e[x].ls,e[y].ls,add(sumx,mul(dec(1,pi),x2)),add(sumy,mul(dec(1,pi),y2)));
e[x].rs=merge(e[x].rs,e[y].rs,add(sumx,mul(pi,x1)),add(sumy,mul(pi,y1)));
e[x].sum=add(e[e[x].ls].sum,e[e[x].rs].sum);
return x;
}
// Builds the segment tree for the subtree of tree node p and returns its
// root index in e[].
//   - no children: a one-leaf tree at the rank of w[p] within b[1..m]
//   - one child  : that child's tree, unchanged
//   - two children: both trees merged, with pi set to w[p]
int solve(int p){
    const int lc=ch[p][0];
    if(!lc){
        const int pos=lower_bound(b+1,b+1+m,w[p])-b;
        ins(rt[p],1,m,pos);
        return rt[p];
    }
    const int left=solve(lc);
    const int rc=ch[p][1];
    if(!rc)return left;
    const int right=solve(rc);
    // pi is set only after both recursive calls, which overwrite it themselves.
    pi=w[p];
    return merge(left,right,0,0);
}
// Walks the segment tree rooted at p over [l, r] and returns, mod P, the sum
// over every leaf position i of  i * b[i] * sum_i^2,  where sum_i is the
// value stored at that leaf after all pending multipliers are pushed down.
int calc(int p,int l,int r){
    if(l!=r){
        pd(p);
        const int mid=(l+r)>>1;
        const int left=calc(e[p].ls,l,mid);
        const int right=calc(e[p].rs,mid+1,r);
        return add(left,right);
    }
    const int sq=mul(e[p].sum,e[p].sum);
    return mul(l,mul(b[l],sq));
}
int main(){
// freopen("testdata.in","r",stdin);
n=read();fp(i,1,n){
fa[i]=read();
ch[fa[i]][0]?ch[fa[i]][1]=i:ch[fa[i]][0]=i;
}fp(i,1,n){
w[i]=read();
ch[i][0]?w[i]=mul(w[i],inv):b[++m]=w[i];
}sort(b+1,b+1+m),m=unique(b+1,b+1+m)-b-1;
printf("%d
",calc(solve(1),1,m));
return 0;
}