首先考虑判断两个点 (a,b) 在一个询问 ((l,r,x)) 下连通:(path(a,b)) 上的点的编号都在 (l sim r) 之间,可以用倍增预处理。
对原树建立点分树。对于每个询问 ((l,r,x)),我们考虑找到这样一个点:
- 这个点是 (x) 在点分树上的祖先
- 这个点与 (x) 连通
- 满足以上两个条件的情况下,要求这个点的深度尽可能小
设找到的这个点为 (y)。因为在原树上的连通块放到点分树上后他们的LCA一定在这个连通块内,那么所有与 (x) 连通的点都必然在 (y) 的子树(点分树)内。我们把这个询问离线下来插进 (y) 里。然后这个询问就和 (x) 没有关系了。
然后对于点分树上每个点 (x),对于它的子树上的一个点 (p),设 (path(x,p)) 上编号最大的点为 (mx),编号最小的点为 (mn),那么他们连通的条件就是 (mngeq l) 且 (mxleq r)。那么和我们像 HH的项链 这题一样处理二维偏序和颜色种类就好了。
代码:
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
#include<vector>
#include<ctime>
using namespace std;
const int N=100009;
int head[N],cnt,n,q,ans[N],F[N],root,col[N];
struct Edge
{
int nxt,to;
}g[N*2];
struct Tree
{
int head[N],cnt,del[N],siz[N];
int dep[N],f[N][30],fmin[N][30],fmax[N][30];
struct Edge
{
int nxt,to;
}g[N*2];
void add(int from,int to)
{
g[++cnt].nxt=head[from];
g[cnt].to=to;
head[from]=cnt;
}
void dfs(int x,int fa)
{
siz[x]=1;
for (int i=head[x];i;i=g[i].nxt)
{
int v=g[i].to;
if(v==fa||del[v])
continue;
dfs(v,x);
siz[x]+=siz[v];
}
}
int Get_Weight(int x)
{
dfs(x,-1);
int k=siz[x]/2,fa=-1;
while(1)
{
int tmp=0;
for (int i=head[x];i;i=g[i].nxt)
{
int v=g[i].to;
if(v==fa||del[v])
continue;
if(siz[tmp]<siz[v])
tmp=v;
}
if(siz[tmp]<=k)
return x;
fa=x,x=tmp;
}
}
void DFS(int x,int fa)
{
fmin[x][0]=fmax[x][0]=x;
for (int i=head[x];i;i=g[i].nxt)
{
int v=g[i].to;
if(v==fa)
continue;
dep[v]=dep[x]+1;
f[v][0]=x;
DFS(v,x);
}
}
pair<int,int> Get_Min_Max(int x,int y)
{
int Min=1<<30,Max=0;
if(dep[x]!=dep[y])
{
if(dep[x]<dep[y])
swap(x,y);
int k=dep[x]-dep[y];
for (int i=20;i>=0;i--)
if(k>=1<<i)
Min=min(Min,fmin[x][i]),
Max=max(Max,fmax[x][i]),
x=f[x][i],k-=1<<i;
}
if(x==y)
return make_pair(min(Min,x),max(Max,x));
for (int i=20;i>=0;i--)
if(f[x][i]!=f[y][i])
Min=min(Min,min(fmin[y][i],fmin[x][i])),
Max=max(Max,max(fmax[y][i],fmax[x][i])),
x=f[x][i],y=f[y][i];
Min=min(Min,min(fmin[y][0],fmin[x][0]));
Max=max(Max,max(fmax[y][0],fmax[x][0]));
return make_pair(min(Min,f[x][0]),max(f[x][0],Max));
}
void work()
{
memset(fmin,0x3f,sizeof(fmin));
add(0,1);
DFS(0,-1);
for (int j=1;j<=20;j++)
for (int i=1;i<=n;i++)
f[i][j]=f[f[i][j-1]][j-1],
fmin[i][j]=min(fmin[i][j-1],fmin[f[i][j-1]][j-1]),
fmax[i][j]=max(fmax[i][j-1],fmax[f[i][j-1]][j-1]);
}
}T;
struct Question
{
int id,l,r;
bool operator < (const Question A)const
{
return r<A.r;
}
};
struct BIT
{
#define MaxN 100000
int c[N];
int lowbit(int x)
{
return x&-x;
}
void Modify(int x,int k)
{
while(x<=MaxN)
c[x]+=k,x+=lowbit(x);
}
int Query(int x)
{
int ans=0;
while(x)
ans+=c[x],x-=lowbit(x);
return ans;
}
}A;
struct Bin
{
int l,r,col;
bool operator < (const Bin A)const
{
return r<A.r;
}
};
Bin bin[N];
vector <Question> b[N];
void add(int from,int to)
{
g[++cnt].nxt=head[from];
g[cnt].to=to;
head[from]=cnt;
}
void init()
{
scanf("%d %d",&n,&q);
for (int i=1;i<=n;i++)
scanf("%d",&col[i]);
for (int i=1,x,y;i<n;i++)
scanf("%d %d",&x,&y),
T.add(x,y),T.add(y,x);
T.work();
}
void build(int x,int fa)
{
int w=T.Get_Weight(x);
T.del[w]=1;
if(fa!=-1)
add(w,fa),add(fa,w),F[w]=fa;
else
root=w;
for (int i=T.head[w];i;i=T.g[i].nxt)
{
int v=T.g[i].to;
if(T.del[v]) continue;
build(v,w);
}
}
int Cnt;
void DFS(int x,int fa,int X)
{
pair<int,int> G=T.Get_Min_Max(X,x);
bin[++Cnt]=(Bin){G.first,G.second,col[x]};
for (int i=head[x];i;i=g[i].nxt)
{
int v=g[i].to;
if(v==fa)
continue;
DFS(v,x,X);
}
}
int Maxc[N];
void dfs(int x,int fa)
{
Cnt=0,DFS(x,fa,x);
int now=1;
sort(b[x].begin(),b[x].end());
sort(bin+1,bin+1+Cnt);
for (int i=0;i<b[x].size();i++)
{
Question v=b[x][i];
while(bin[now].r<=v.r&&now<=Cnt)
{
if(Maxc[bin[now].col]==0)
A.Modify(bin[now].l,1),
Maxc[bin[now].col]=bin[now].l;
else if(Maxc[bin[now].col]<bin[now].l)
A.Modify(bin[now].l,1),
A.Modify(Maxc[bin[now].col],-1),
Maxc[bin[now].col]=bin[now].l;
now++;
}
// printf("%d %d
",v.r,v.l-1);
ans[v.id]=A.Query(v.r)-A.Query(v.l-1);
}
for (int i=1;i<=Cnt;i++)
if(Maxc[bin[i].col]!=0)
A.Modify(Maxc[bin[i].col],-1),
Maxc[bin[i].col]=0;
for (int i=head[x];i;i=g[i].nxt)
{
int v=g[i].to;
if(v==fa)
continue;
dfs(v,x);
}
}
void print(int x,int fa)
{
for (int i=head[x];i;i=g[i].nxt)
{
int v=g[i].to;
if(v==fa)
continue;
printf("%d %d
",x,v);
print(v,x);
}
}
void work()
{
build(1,-1);
// print(root,-1);
/*srand(time(0));
puts("wyj!!!");
for (int i=1;i<=10;i++)
{
int x=rand()%n+1,y=rand()%n+1;
if(x>y) swap(x,y);
pair<int,int> G=T.Get_Min_Max(x,y);
printf("%d %d %d %d
",x,y,G.first,G.second);
}*/
// pair<int,int> G=T.Get_Min_Max(5,10);
// printf("%d %d
",G.first,G.second);
for (int i=1;i<=q;i++)
{
int x,y,z;
scanf("%d %d %d",&x,&y,&z);
int pos=z;
for (int j=F[z];j;j=F[j])
{
pair<int,int> G=T.Get_Min_Max(j,z);
if(G.first>=x&&G.second<=y) pos=j;
}
b[pos].push_back((Question){i,x,y});
}
dfs(root,-1);
for (int i=1;i<=q;i++)
printf("%d
",ans[i]);
}
int main()
{
init();
work();
return 0;
}