题意:给定一棵树 每个结点有一个颜色 然后又m个询问
询问:x d 问x的子树内不超过dep[x]+d 深度的子树结点一共有多少个颜色?
1、可以先将问题简化为问整个子树内有多少个不同的颜色
暴力解法树套树 但是可以用一个技巧来快速维护: 一个颜色一个颜色地处理 把所有相同颜色的点按照dfs序排序,每个点给自己的位置贡献1,相邻的两个点给lca贡献−1。然后只要区间内存在这种颜色,则其子树内的权值和必定为1。那么只需要这样子染好所有颜色之后询问子树和。
所以如果问题是这样的话只要一个线段树和一个树剖就可以完美的解决了!
2、回到当前的问题 正解是将线段树进化为主席树可持久化地维护即可 主席树的历史版本为深度!!! 所以查询得话就是 qsum(id[x],id[x]+siz[x]-1,1,n, T[ dep[x]+d] ) (范围为整个子树得编号 树为dep+d得深度得数)
#include<bits/stdc++.h> using namespace std; #define rep(i,a,b) for(int i=(a);i<=(b);i++) #define repp(i,a,b) for(int i=(a);i>=(b);--i) #define ll long long #define see(x) (cerr<<(#x)<<'='<<(x)<<endl) #define inf 0x3f3f3f3f #define CLR(A,v) memset(A,v,sizeof A) ////////////////////////////////// const int N=1e6+10; int T[N],t[N<<5],n,m,ncnt,lson[N<<5],rson[N<<5]; void upnode(int x,int v,int l,int r,int pre,int &pos) { pos=++ncnt; lson[pos]=lson[pre];rson[pos]=rson[pre]; t[pos]=t[pre]+v; int m=(l+r)>>1; if(l==r)return; if(x<=m)upnode(x,v,l,m,lson[pre],lson[pos]); else upnode(x,v,m+1,r,rson[pre],rson[pos]); } int qsum(int L,int R,int l,int r,int pos) { if(L<=l&&r<=R)return t[pos]; int ans=0,m=(l+r)>>1; if(L<=m)ans+=qsum(L,R,l,m,lson[pos]); if(R>m)ans+=qsum(L,R,m+1,r,rson[pos]); return ans; } ///////////////////////// int tot,id[N],fa[N],top[N],siz[N],son[N],c[N],dep[N],pos,head[N],a,pre[N]; struct Edge{int to,nex;}edge[N<<1]; void add(int a,int b){edge[++pos]=(Edge){b,head[a]};head[a]=pos;} void dfs1(int x,int f) { fa[x]=f;dep[x]=dep[f]+1;son[x]=0;siz[x]=1; for(int i=head[x];i;i=edge[i].nex) { int v=edge[i].to; if(v==f)continue; dfs1(v,x);siz[x]+=siz[v]; if(siz[v]>siz[son[x]])son[x]=v; } } void dfs2(int x,int topf) { top[x]=topf;id[x]=++tot;pre[tot]=x; if(son[x])dfs2(son[x],topf); for(int i=head[x];i;i=edge[i].nex) { int v=edge[i].to; if(v==fa[x]||v==son[x])continue; dfs2(v,v); } } int getlca(int x,int y) { while(top[x]!=top[y]) { if(dep[top[x]]<dep[top[y]])swap(x,y); x=fa[top[x]]; } return dep[x]>dep[y]?y:x; } /////////////////////////// set<int>s[N]; set<int>::iterator t1,t2,t3; int p[N]; bool cmp(int a,int b){return dep[a]<dep[b];} int main() { int cas;cin>>cas; while(cas--) { scanf("%d%d",&n,&m); rep(i,1,n)scanf("%d",&c[i]); rep(i,2,n)scanf("%d",&a),add(i,a),add(a,i); dfs1(1,0);dfs2(1,1); rep(i,1,n)p[i]=i; sort(p+1,p+1+n,cmp); for(int i=1,j=1;i<=dep[p[n]]&&j<=n;i++) { T[i]=T[i-1]; while(j<=n&&dep[p[j]]==i) { int x=p[j],idx=id[x]; s[c[x]].insert(id[x]); t2=s[c[x]].find(id[x]); t1=t2;t1--; t3=t2;t3++; upnode(id[x],1,1,n,T[i],T[i]); if(t2!=s[c[x]].begin())upnode(id[getlca(x,pre[*t1])],-1,1,n,T[i],T[i]);//,printf("1 lca=%d id=%d ",getlca(x,pre[*t1]),id[getlca(x,pre[*t1])]); if(t3!=s[c[x]].end())upnode(id[getlca(x,pre[*t3])],-1,1,n,T[i],T[i]);//,printf("2 lca=%d id=%d ",getlca(x,pre[*t3]),id[getlca(x,pre[*t3])]); if(t2!=s[c[x]].begin()&&t3!=s[c[x]].end()) upnode(id[getlca(pre[*t1],pre[*t3])],1,1,n,T[i],T[i]);//,printf("lca=%d id=%d ",getlca(pre[*t1],pre[*t3]),id[getlca(pre[*t1],pre[*t3])]); j++; } } int x,y,ans=0; while(m--) scanf("%d%d",&x,&y),x^=ans,y^=ans,printf("%d ",ans=qsum(id[x],id[x]+siz[x]-1,1,n,T[ min(dep[p[n]],dep[x]+y)])); CLR(head,0);pos=0; rep(i,1,n)s[c[i]].clear(); ncnt=tot=0; } return 0; }