题意:
给定(x),(d),求x子树里深度不超过dep[x]+d的所有点中有多少种颜色。
强制在线。
一般的,询问区间颜色数时,可以求出每个数的后继,然后就变成了区间内大于某数的数,进而使用树状数组或主席树。
然而,询问树上的颜色数,还有一种更好的方法:
考虑某种颜色的贡献:
把这种颜色的所有点到根的链,进行树链合并,即把这些点按照dfs序排序,每个点到根路径+1,相邻的lca到路径-1。
可以使用树上差分。
对于本题,把所有点按深度排序并依次加入,并处理出每个点按照DFS序排序的前面,后面第一个点。
由于强制在线,用主席树维护即可。
#include <stdio.h>
#include <vector>
#include <stdlib.h>
using namespace std;
#define MN 3800010
int fr[100010],ne[100010],v[100010],bs=0;
void addb(int a,int b)
{
v[bs]=b;
ne[bs]=fr[a];
fr[a]=bs++;
}
int fa[100010],sd[100010],son[100010],zl[100010],zr[100010],tm=0;
int dfs1(int u,int f)
{
fa[u]=f;sd[u]=sd[f]+1;
int he=1,ma=0;
son[u]=-1;zl[u]=tm++;
for(int i=fr[u];i!=-1;i=ne[i])
{
int rt=dfs1(v[i],u);
he+=rt;
if(rt>ma)
{
ma=rt;
son[u]=v[i];
}
}
zr[u]=tm;
return he;
}
int top[100010];
void dfs2(int u,int tp)
{
top[u]=tp;
if(son[u]==-1)
return;
dfs2(son[u],tp);
for(int i=fr[u];i!=-1;i=ne[i])
{
if(v[i]!=son[u])
dfs2(v[i],v[i]);
}
}
int getlca(int x,int y)
{
while(top[x]!=top[y])
{
if(sd[top[x]]>sd[top[y]])
x=fa[top[x]];
else
y=fa[top[y]];
}
if(sd[x]<sd[y])
return x;
return y;
}
int cl[MN],cr[MN],he[MN],sl=0;
int jianshu(int l,int r)
{
int rt=++sl;
he[rt]=cl[rt]=cr[rt]=0;
if(l+1==r)
return rt;
int m=(l+r)>>1;
cl[rt]=jianshu(l,m);
cr[rt]=jianshu(m,r);
return rt;
}
int add(int i,int l,int r,int j,int x)
{
int rt=++sl;
he[rt]=he[i]+x;
if(l+1==r)
return rt;
int m=(l+r)>>1;
cl[rt]=cl[i];cr[rt]=cr[i];
if(j<m)
cl[rt]=add(cl[i],l,m,j,x);
else
cr[rt]=add(cr[i],m,r,j,x);
return rt;
}
int sum(int i,int l,int r,int L,int R)
{
if(R<=l||r<=L)
return 0;
if(L<=l&&r<=R)
return he[i];
int m=(l+r)>>1;
return sum(cl[i],l,m,L,R)+sum(cr[i],m,r,L,R);
}
int tl[100010],tr[100010],sz[100010],wl[100010],wr[100010],px[100010],ro[100010];
int cmp(const void*a,const void*b)
{
return zl[*(int*)a]-zl[*(int*)b];
}
vector<int> co[100010],ve[100010];
int main()
{
int T;
scanf("%d",&T);
while(T--)
{
int n,m;
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
{
scanf("%d",&sz[i]);
co[sz[i]].push_back(i);
fr[i]=-1;
}
for(int i=2;i<=n;i++)
{
int a;
scanf("%d",&a);
addb(a,i);
}
dfs1(1,0);dfs2(1,1);
for(int i=1;i<=n;i++)
{
int k=co[i].size();
for(int j=0;j<k;j++)
px[j]=co[i][j];
qsort(px,k,sizeof(int),cmp);
for(int j=0;j<k;j++)
{
tl[j]=tr[j]=j;
while(tl[j]>0&&sd[px[tl[j]-1]]>sd[px[j]])
tl[j]=tl[tl[j]-1];
}
for(int j=k-1;j>=0;j--)
{
while(tr[j]+1<k&&sd[px[tr[j]+1]]>=sd[px[j]])
tr[j]=tr[tr[j]+1];
}
for(int j=0;j<k;j++)
{
int u=px[j];
if(tl[j]==0)wl[u]=-1;
else wl[u]=px[tl[j]-1];
if(tr[j]==k-1)wr[u]=-1;
else wr[u]=px[tr[j]+1];
}
}
for(int i=1;i<=n;i++)
ve[sd[i]].push_back(i);
ro[0]=jianshu(0,tm);
for(int i=1;i<=n;i++)
{
ro[i]=ro[i-1];
int k=ve[i].size();
for(int j=0;j<k;j++)
px[j]=ve[i][j];
qsort(px,k,sizeof(int),cmp);
for(int j=0;j<k;j++)
{
int u=px[j];
ro[i]=add(ro[i],0,tm,zl[u],1);
if(wl[u]!=-1&&wr[u]!=-1)
ro[i]=add(ro[i],0,tm,zl[getlca(wl[u],wr[u])],1);
if(wl[u]!=-1)
ro[i]=add(ro[i],0,tm,zl[getlca(u,wl[u])],-1);
if(wr[u]!=-1)
ro[i]=add(ro[i],0,tm,zl[getlca(u,wr[u])],-1);
}
}
for(int i=0,la=0;i<m;i++)
{
int x,d;
scanf("%d%d",&x,&d);
x^=la;d^=la;
int z=sd[x]+d;
if(z>n)z=n;
la=sum(ro[z],0,tm,zl[x],zr[x]);
printf("%d
",la);
}
bs=tm=sl=0;
for(int i=1;i<=n;i++)
{
ve[i].clear();
co[i].clear();
}
}
return 0;
}