链接
线段树合并,计算 (sum) 值时,要记录前缀和与后缀和,维护区间和与区间乘,
合并后的节点如果直接继承之前的节点的话要注意提前将值提取。
#include<bits/stdc++.h>
#define IL inline
#define LL long long
using namespace std;
const int N=3e5+3,p=998244353;
int n,m,k,num,inv,siz[N],P[N],val[N],b[N],ans,lc[N],rc[N];
int cnt,rt[N],ls[N*400],rs[N*400],mul[N*400],sum[N*400];
IL int in(){
char c;int f=1;
while((c=getchar())<'0'||c>'9')
if(c=='-') f=-1;
int x=c-'0';
while((c=getchar())>='0'&&c<='9')
x=x*10+c-'0';
return x*f;
}
IL int mod(int x){return x>=p?x-p:x;}
IL int ksm(int a,int b){
int c=1;
while(b){
if(b&1) c=1ll*c*a%p;
a=1ll*a*a%p,b>>=1;
}
return c;
}
IL void Mul(int o,int v){mul[o]=1ll*mul[o]*v%p,sum[o]=1ll*sum[o]*v%p;}
IL void pudn(int o){
if(mul[o]==1) return;
if(ls[o]) Mul(ls[o],mul[o]);
if(rs[o]) Mul(rs[o],mul[o]);
mul[o]=1;
}
void mdy(int &o,int l,int r,int u){
o=++cnt,sum[o]=1,mul[o]=1;
if(l==r) return;
int mid=l+r>>1;
if(u<=mid) mdy(ls[o],l,mid,u);
else mdy(rs[o],mid+1,r,u);
}
void merge(int P,int &o,int p1,int p2,int s1,int s2){
if(!p1) return Mul(o=p2,mod(1ll*s1*P%p+1ll*mod(1-s1+p)*mod(1-P+p)%p));
if(!p2) return Mul(o=p1,mod(1ll*s2*P%p+1ll*mod(1-s2+p)*mod(1-P+p)%p));
pudn(o=p1),pudn(p2);
merge(P,rs[o],rs[p1],rs[p2],mod(s1+sum[ls[p1]]),mod(s2+sum[ls[p2]]));//注意先递归右边的以避免值得更新
merge(P,ls[o],ls[p1],ls[p2],s1,s2);
sum[o]=mod(sum[ls[o]]+sum[rs[o]]);
}
void dfs(int u){
if(!lc[u]) return mdy(rt[u],1,num,lower_bound(b+1,b+num+1,val[u])-b);
if(!rc[u]) dfs(lc[u]),rt[u]=rt[lc[u]];
else dfs(lc[u]),dfs(rc[u]),merge(P[u],rt[u],rt[lc[u]],rt[rc[u]],0,0);
}
void calc(int o,int l,int r){
if(l==r){ans=mod(ans+1ll*l*b[l]%p*sum[o]%p*sum[o]%p);return;}
int mid=l+r>>1;pudn(o);
calc(ls[o],l,mid),calc(rs[o],mid+1,r);
}
int main()
{
int x,y;
n=in(),in(),inv=ksm(1e4,p-2);
for(int i=2;i<=n;++i){
x=in();
if(lc[x]) rc[x]=i;
else lc[x]=i;
}
for(int i=1;i<=n;++i)
if(lc[i]) P[i]=1ll*in()*inv%p;
else val[i]=b[++num]=in();
sort(b+1,b+num+1);
dfs(1),calc(rt[1],1,num);
printf("%d
",ans);
return 0;
}