$O(n^2)$ 的式子是好列的,然后我们发现这是一个关于前后缀的转移.
用线段树合并优化这一过程.
具体地,分别维护 $x,y$ 的后缀和.
这里要注意:由于这道题中两个不同子树肯定没有交集,所以在线段树合并的时候肯定会合并到一个点,使得两个树中一个为空.
然后由于另一个是空的,就没有合并的必要了,这样整个区间乘的就是一个相同的数了.
这样就只需要维护一个乘法标记就行了.
code:
#include <bits/stdc++.h> #define ll long long #define mod 998244353 #define N 300008 #define lson s[x].ls #define rson s[x].rs #define setIO(s) freopen(s".in","r",stdin) using namespace std; int fa[N],cn,n,tot,ans; int ch[N][2],val[N],perc[N],A[N],v[N]; int qpow(int x,int y) { int tmp=1; for(;y;y>>=1,x=(ll)x*x%mod) if(y&1) tmp=(ll)tmp*x%mod; return tmp; } inline int INV(int x) { return qpow(x,mod-2); } struct data { int ls,rs; ll sum,tag; // data(){ ls=rs=sum=0,tag=1; } }s[N*50]; int rt[N]; inline int newnode() { return ++tot; } inline void pushup(int x) { s[x].sum=(ll)(s[lson].sum+s[rson].sum)%mod; } void update(int &x,int l,int r,int p,int v) { if(!x) x=newnode(),s[x].tag=1; if(l==r) { s[x].sum=s[x].tag=v; return; } int mid=(l+r)>>1; if(p<=mid) update(lson,l,mid,p,v); else update(rson,mid+1,r,p,v); pushup(x); } inline void mark(int x,ll v) { s[x].tag=(ll)s[x].tag*v%mod; s[x].sum=(ll)s[x].sum*v%mod; } inline void pushdown(int x) { if(s[x].tag!=1) { if(lson) mark(lson,s[x].tag); if(rson) mark(rson,s[x].tag); s[x].tag=1; } } // s1-> 小的 // s2-> 多的 int merge(int x,int y,ll det,ll x1,ll x2,ll y1,ll y2) { if(!x&&!y) return 0; if(!x) { ll up=(ll)((ll)(1-det+mod)%mod*x2%mod+(ll)det*x1%mod)%mod; mark(y,up); return y; } if(!y) { ll up=(ll)((ll)(1-det+mod)%mod*y2%mod+(ll)det*y1%mod)%mod; mark(x,up); return x; } int now=newnode(); pushdown(x),pushdown(y); int xr=(ll)(x2+s[s[x].rs].sum)%mod; int yr=(ll)(y2+s[s[y].rs].sum)%mod; int xl=(ll)(x1+s[s[x].ls].sum)%mod; int yl=(ll)(y1+s[s[y].ls].sum)%mod; s[now].tag=1; s[now].ls=merge(s[x].ls,s[y].ls,det,x1,xr,y1,yr); s[now].rs=merge(s[x].rs,s[y].rs,det,xl,x2,yl,y2); pushup(now); return now; } void dfs(int x) { int l=ch[x][0],r=ch[x][1]; if(!l) update(rt[x],1,cn,v[x],1); else if(!r) dfs(l),rt[x]=rt[l]; else dfs(l),dfs(r),rt[x]=merge(rt[l],rt[r],1ll*perc[x],0,0,0,0); } void output(int x,int l,int r) { if(!x) return; if(l==r) { (ans+=(ll)l*A[l]%mod*s[x].sum%mod*s[x].sum%mod)%=mod; return; } int mid=(l+r)>>1; pushdown(x); output(s[x].ls,l,mid); output(s[x].rs,mid+1,r); } int main() { // setIO("input"); scanf("%d",&n); for(int i=1;i<=n;++i) { scanf("%d",&fa[i]); if(ch[fa[i]][0]) ch[fa[i]][1]=i; else ch[fa[i]][0]=i; } for(int i=1;i<=n;++i) { int a; scanf("%d",&a); if(!ch[i][0]) val[i]=a,A[++cn]=val[i]; else perc[i]=(ll)a*INV(10000)%mod; } sort(A+1,A+1+cn); for(int i=1;i<=n;++i) if(!ch[i][0]) v[i]=lower_bound(A+1,A+1+cn,val[i])-A; dfs(1),output(rt[1],1,cn),printf("%d ",ans); return 0; }