题目描述
给定一棵n点的树,每个节点上有一个颜色,每次询问一个点的子树中与这个点距离不超过d的点的颜色有多少种。强制在线。
题解
线段树合并好题
刚刚学习了线段树合并就被忽悠来做这题
然后显然不会
所以就从网上搜了题解
一看这题首先想到主席树
但是主席树只能同时处理两种要求
可以先考虑如何算以点u为根的子树与点u的距离不超过d的点的个数
先对于每个节点建主席树,以深度为下标表示深度为i的节点个数
这玩意儿显然可以从子树的线段树上合并上来
然后再考虑如何算以点u为根的子树中颜色种数
也可以对于每个节点建主席树,以颜色为下标表示颜色为i的点在以u为根的子树内出现的最浅深度
然后这个东西也可以从子树的线段树中合并上来
题目要求是求出到u的距离小于等于d的子树的节点的颜色
所以要在最后答案满足可加性的话只能保留以u为根的子树内同一种颜色出现深度最小的一个
所以我们需要对第一棵线段树做一些修改
我们让第一棵线段树以深度为下标,下标i表示子树内有几种颜色出现的最小深度是i
这样信息就满足可加性了
那第一棵线段树怎么计算呢?
首先先按照以点u为根的子树与点u的距离不超过d的点个数的方式计算
然后在第二棵线段树合并点u的时候合并到最底层发现左右子树都出现了这种颜色,就要在第一棵线段树上减掉深度大的那个
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
const int M = 200005 ;
using namespace std ;
inline int read() {
char c = getchar() ; int x = 0 , w = 1 ;
while(c>'9'||c<'0') { if(c=='-') w = -1 ; c = getchar() ; }
while(c>='0'&&c<='9') { x = x*10+c-'0' ; c = getchar() ; }
return x*w ;
}
int n ;
int tot1 , tot2 , dep[M] ;
int col[M] , fa[M] , rt1[M] , rt2[M] ;
struct Node1 { int l , r , size ; } t1[M * 50] ;
struct Node2 { int l , r , Minpos ; } t2[M * 50] ;
inline void Clear() {
tot1 = tot2 = 0 ;
memset(rt1 , 0 , sizeof(rt1)) ; memset(rt2 , 0 , sizeof(rt2)) ;
}
void Insert1(int x , int v , int l , int r , int &now) {
t1[++tot1] = t1[now] ; now = tot1 ; t1[now].size += v ;
if(l == r) return ;
int mid = (l + r) >> 1 ;
if(mid >= x) Insert1(x , v , l , mid , t1[now].l) ;
else Insert1(x , v , mid + 1 , r , t1[now].r) ;
}
void Insert2(int x , int v , int l , int r , int &now) {
t2[++tot2] = t2[now] ; now = tot2 ; t2[now].Minpos = v ;
if(l == r) return ;
int mid = (l + r) >> 1 ;
if(mid >= x) Insert2(x , v , l , mid , t2[now].l) ;
else Insert2(x , v , mid + 1 , r , t2[now].r) ;
}
int Merge1(int x , int y , int l , int r) {
if(!x || !y) { return x + y ; }
int cnt = ++tot1 ; t1[cnt].size = t1[x].size + t1[y].size ;
if(l == r) return cnt ;
int mid = (l + r) >> 1 ;
t1[cnt].l = Merge1(t1[x].l , t1[y].l , l , mid) ;
t1[cnt].r = Merge1(t1[x].r , t1[y].r , mid + 1 , r) ;
return cnt ;
}
int Merge2(int x , int y , int l , int r , int id) {
if(!x || !y) return x + y ;
int cnt = ++tot2 ;
if(l == r) {
Insert1(max(t2[x].Minpos , t2[y].Minpos) , -1 , 1 , n , rt1[id]) ;
t2[cnt].Minpos = min(t2[x].Minpos , t2[y].Minpos) ;
return cnt ;
}
int mid = (l + r) >> 1 ;
t2[cnt].l = Merge2(t2[x].l , t2[y].l , l , mid , id) ;
t2[cnt].r = Merge2(t2[x].r , t2[y].r , mid + 1 , r , id) ;
return cnt ;
}
int query(int L , int R , int l , int r , int now) {
if(l > R || r < L) return 0 ;
if(l == L && r == R) return t1[now].size ;
int mid = (l + r) >> 1 ;
if(mid >= R) return query(L , R , l , mid , t1[now].l) ;
else if(mid < L) return query(L , R , mid + 1 , r , t1[now].r) ;
else return query(L , mid , l , mid , t1[now].l) + query(mid + 1 , R , mid + 1 , r , t1[now].r) ;
}
int main() {
int Case = read() ;
while(Case --) {
Clear() ;
n = read() ; int Q = read() ;
for(int i = 1 ; i <= n ; i ++) col[i] = read() ;
fa[1] = 1 ; dep[1] = 1 ;
for(int i = 2 ; i <= n ; i ++) {
fa[i] = read() ;
dep[i] = dep[fa[i]] + 1 ;
}
for(int i = n ; i >= 1 ; i --) {
Insert1(dep[i] , 1 , 1 , n , rt1[i]) ;
Insert2(col[i] , dep[i] , 1 , n , rt2[i]) ;
}
for(int i = n ; i >= 2 ; i --) {
rt1[fa[i]] = Merge1(rt1[fa[i]] , rt1[i] , 1 , n) ;
rt2[fa[i]] = Merge2(rt2[fa[i]] , rt2[i] , 1 , n , fa[i]) ;
}
int LastAns = 0 ;
while(Q --) {
int x = read() , d = read() ;
x ^= LastAns ; d ^= LastAns ;
LastAns = query(dep[x] , dep[x] + d , 1 , n , rt1[x]) ;
printf("%d
",LastAns) ;
}
}
return 0 ;
}