每次询问是关于 (x) 所在的连通块,所以考虑用点分树来解决本题。
点分树上每个节点所对应的子树,都是原树中的一个连通块。询问中给定 (x) 和区间 ([l,r]),其就已经确定了原树的一个连通块,所以可以在点分树上找到最大的一个子树包含该连通块,统计其内部合法点的个数即可。
首先处理出点分树上每个点在原树上到点分树根节点的链上所有节点路径经过节点编号的最小值和最大值。对于每个询问,在 (x) 到根节点的链上找到深度最浅的一个点,且原树上 (x) 到其路径经过节点编号的最小值和最大值在区间 ([l,r]) 内,这个节点所对应的子树就包含了该询问所对应的连通块,将询问挂到这个节点上。对于这个询问,(x) 和找到的这个节点是连通的,所以只需统计子树内有多少节点是和该节点连通,即其子树内有多少个点在原树上到该节点的路径经过节点编号的最小值和最大值在区间 ([l,r]) 内。
然后可以遍历点分树上每个点的子树来处理询问,因为点分树所有点的子树和为 (n log n) 级别,所以复杂度正确。
设一个节点到其点分树上子树的根节点的节点编号最小值为 (L),最大值为 (R),对于询问 ([l,r]),只有当 (l leqslant L,r geqslant R),且其颜色是第一次出现,该点才会对这个询问产生贡献。可以先对节点和询问的 (L) 进行排序,然后维护每种颜色的 (R) 的最小值,让最小值对询问产生贡献,用树状数组维护即可。
(code:)
#include<bits/stdc++.h>
#define maxn 200010
#define inf 1000000000
#define lowbit(x) (x&(-x))
using namespace std;
template<typename T> inline void read(T &x)
{
x=0;char c=getchar();bool flag=false;
while(!isdigit(c)){if(c=='-')flag=true;c=getchar();}
while(isdigit(c)){x=(x<<1)+(x<<3)+(c^48);c=getchar();}
if(flag)x=-x;
}
int n,m,tot,root;
int v[maxn],mi[maxn],ma[maxn],siz[maxn],ans[maxn],t[maxn];
bool vis[maxn];
struct node
{
int l,r,id,type;
};
bool cmp(const node &a,const node &b)
{
if(a.l==b.l) return a.type<b.type;
return a.l>b.l;
}
vector<node> ve[maxn],p[maxn];
struct edge
{
int to,nxt;
}e[maxn];
int head[maxn],edge_cnt;
void add(int from,int to)
{
e[++edge_cnt]=(edge){to,head[from]};
head[from]=edge_cnt;
}
void dfs_root(int x,int fath)
{
siz[x]=1,ma[x]=0;
for(int i=head[x];i;i=e[i].nxt)
{
int y=e[i].to;
if(vis[y]||y==fath) continue;
dfs_root(y,x),siz[x]+=siz[y];
ma[x]=max(ma[x],siz[y]);
}
ma[x]=max(ma[x],tot-siz[x]);
if(ma[x]<ma[root]) root=x;
}
void dfs_find(int x,int fath,int l,int r)
{
p[x].push_back((node){l,r,root,0});
ve[root].push_back((node){l,r,v[x],0});
for(int i=head[x];i;i=e[i].nxt)
{
int y=e[i].to;
if(vis[y]||y==fath) continue;
dfs_find(y,x,min(l,y),max(r,y));
}
}
void solve(int x)
{
int now=tot;
vis[x]=true,dfs_find(x,0,x,x);
for(int i=head[x];i;i=e[i].nxt)
{
int y=e[i].to;
if(vis[y]) continue;
root=0,tot=siz[y];
if(siz[y]>siz[x]) tot=now-siz[x];
dfs_root(y,x),solve(root);
}
}
void update(int x,int v)
{
while(x<=n) t[x]+=v,x+=lowbit(x);
}
int query(int x)
{
int v=0;
while(x) v+=t[x],x-=lowbit(x);
return v;
}
int main()
{
read(n),read(m);
for(int i=1;i<=n;++i) read(v[i]),mi[v[i]]=inf;
for(int i=1;i<n;++i)
{
int x,y;
read(x),read(y);
add(x,y),add(y,x);
}
tot=ma[0]=n,dfs_root(1,0),solve(root);
for(int i=1;i<=m;++i)
{
int l,r,x;
read(l),read(r),read(x);
for(int j=0;j<p[x].size();++j)
{
if(l<=p[x][j].l&&r>=p[x][j].r)
{
ve[p[x][j].id].push_back((node){l,r,i,1});
break;
}
}
}
for(int i=1;i<=n;++i)
{
sort(ve[i].begin(),ve[i].end(),cmp);
for(int j=0;j<ve[i].size();++j)
{
node x=ve[i][j];
if(x.type) ans[x.id]=query(x.r);
else if(x.r<mi[x.id])
update(mi[x.id],-1),update(x.r,1),mi[x.id]=x.r;
}
for(int j=0;j<ve[i].size();++j)
{
node x=ve[i][j];
if(!x.type&&mi[x.id]==x.r)
update(x.r,-1),mi[x.id]=inf;
}
}
for(int i=1;i<=m;++i) printf("%d
",ans[i]);
return 0;
}