是个傻题
显然枚举每一条路径经过了多少次,如果(u,v)在树上不是祖先关系的话经过((u,v))这条路径的路径条数就是(sum_u imes sum_v)
于是我们子树大小映射到( m Trie)上去,树形( m dp)一下就可以求出所有点对产生的贡献了
但是这样祖先关系的节点就算错了,我们发现这也非常好算,( m dfs)的时候拿( m LCT)维护一下就好了
代码
#include<bits/stdc++.h>
#define re register
inline int read() {
char c=getchar();int x=0;while(c<'0'||c>'9') c=getchar();
while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar();return x;
}
const int maxn=3e5+5;
const int mod=998244353;
struct E{int v,nxt;}e[maxn];
inline int qm(int x) {return x>=mod?x-mod:x;}
inline int dqm(int x) {return x<0?x+mod:x;}
int n,m,num,rt,ans,sm[maxn],head[maxn],d[maxn];
char S[maxn];
struct Trie {
E e[maxn<<1];
int head[maxn],num,v[maxn],deep[maxn];
inline void add(int x,int y) {
e[++num].v=y;e[num].nxt=head[x];head[x]=num;
}
void pdfs(int x) {
for(re int i=head[x];i;i=e[i].nxt) deep[e[i].v]=deep[x]+1,pdfs(e[i].v);
}
void dfs(int x,int dep) {
for(re int i=head[x];i;i=e[i].nxt) {
dfs(e[i].v,dep+1);
ans=qm(ans+1ll*dep*v[x]%mod*v[e[i].v]%mod);
v[x]=qm(v[x]+v[e[i].v]);
}
}
}T;
struct LinkCutTree {
int fa[maxn],ch[maxn][2],rev[maxn],tag[maxn],st[maxn],top,sum[maxn],a[maxn],sz[maxn];
inline int nrt(int x) {return ch[fa[x]][1]==x||ch[fa[x]][0]==x;}
inline void pushup(int x) {
sz[x]=1+sz[ch[x][0]]+sz[ch[x][1]];sum[x]=a[x];
if(ch[x][0]) sum[x]=qm(sum[x]+sum[ch[x][0]]);
if(ch[x][1]) sum[x]=qm(sum[x]+sum[ch[x][1]]);
}
inline void work(int x,int v) {
a[x]=qm(a[x]+v);tag[x]=qm(tag[x]+v);
sum[x]=qm(sum[x]+1ll*sz[x]*v%mod);
}
inline void pushdown(int x) {
if(tag[x]) {
if(ch[x][0]) work(ch[x][0],tag[x]);
if(ch[x][1]) work(ch[x][1],tag[x]);
tag[x]=0;
}
if(rev[x]) {
rev[x]=0;rev[ch[x][0]]^=1;rev[ch[x][1]]^=1;
std::swap(ch[ch[x][0]][0],ch[ch[x][0]][1]);
std::swap(ch[ch[x][1]][0],ch[ch[x][1]][1]);
}
}
inline void rotate(int x) {
int y=fa[x],z=fa[y],w=ch[y][1]==x,k=ch[x][w^1];
if(nrt(y)) ch[z][ch[z][1]==y]=x;
ch[x][w^1]=y,ch[y][w]=k;
pushup(y),pushup(x);fa[k]=y,fa[y]=x,fa[x]=z;
}
inline void splay(int x) {
int y=x;top=0;st[++top]=x;
while(nrt(y)) y=fa[y],st[++top]=y;
while(top) pushdown(st[top--]);
while(nrt(x)) {
int y=fa[x];
if(nrt(y)) rotate((ch[fa[y]][1]==y)^(ch[y][1]==x)?x:y);
rotate(x);
}
}
inline void access(int x) {
for(re int y=0;x;y=x,x=fa[x])
splay(x),ch[x][1]=y,pushup(x);
}
inline void mrt(int x) {
access(x);splay(x);rev[x]^=1;std::swap(ch[x][0],ch[x][1]);
}
inline void link(int x,int y) {
mrt(x);fa[x]=y;T.add(x,y);
}
inline void split(int x,int y) {
mrt(x);access(y);splay(y);
}
inline void ins(int x,int y,int v) {
split(x,y);v=dqm(v);work(y,v);
}
inline int query(int x,int y) {
split(x,y);
return dqm(sum[y]-a[y]);
}
}lct;
inline void add(int x,int y) {
e[++num].v=y;e[num].nxt=head[x];head[x]=num;
}
void dfs1(int x) {
sm[x]=1;
for(re int i=head[x];i;i=e[i].nxt) dfs1(e[i].v),sm[x]+=sm[e[i].v];
}
void dfs2(int x) {
ans=qm(ans+1ll*sm[x]*lct.query(d[x],1)%mod);
for(re int i=head[x];i;i=e[i].nxt) {
lct.ins(1,d[x],n-sm[e[i].v]-sm[x]);
dfs2(e[i].v);
lct.ins(1,d[x],sm[x]+sm[e[i].v]-n);
}
}
int main() {
n=read(),m=read();
for(re int x,i=1;i<=n;i++) {
x=read();if(x) add(x,i);else rt=i;
}
for(re int x,i=1;i<=m;i++) {
x=read();if(x) lct.link(x,i);
}
dfs1(rt);scanf("%s",S+1);T.pdfs(1);
for(re int i=1;i<=n;i++) {
d[i]=read();
ans=qm(ans+1ll*sm[i]*T.deep[d[i]]%mod*T.v[d[i]]%mod);
T.v[d[i]]=qm(T.v[d[i]]+sm[i]);
}
T.dfs(1,0);dfs2(rt);printf("%d
",ans);
return 0;
}