树上莫队模板题。
使用欧拉序将树上路径转化为普通区间。
之后莫队维护即可。不要忘记特判LCA
1 #include<iostream> 2 #include<cstdio> 3 #include<cstring> 4 #include<cmath> 5 #include<algorithm> 6 #define N 200005 7 using namespace std; 8 int read() 9 { 10 int x=0,f=1;char ch=getchar(); 11 while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} 12 while(ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+(ch^48);ch=getchar();} 13 return x*f; 14 } 15 int n,m,val[N],inp[N],cnt[N],v[N],nxt[N],head[N],cntt,tot,siz,ord[N],ncnt,fir[N],lst[N]; 16 int fa[N][25],dep[N],now,l=1,r=0,ans[N],vis[N]; 17 struct node 18 { 19 int l,r,x,lca; 20 }q[N]; 21 bool cmp(node a,node b) 22 { 23 if(a.l/siz!=b.l/siz)return a.l/siz<b.l/siz; 24 return a.r<b.r; 25 } 26 void add(int a,int b) 27 { 28 v[++cntt]=b; 29 nxt[cntt]=head[a]; 30 head[a]=cntt; 31 } 32 void dfs1(int x,int f) 33 { 34 dep[x]=dep[f]+1; 35 for(int i=0;i<=19;i++)fa[x][i+1]=fa[fa[x][i]][i]; 36 for(int i=head[x];i;i=nxt[i]) 37 { 38 if(v[i]==f)continue; 39 fa[v[i]][0]=x; 40 dfs1(v[i],x); 41 } 42 } 43 int lca(int x,int y) 44 { 45 if(dep[x]<dep[y])swap(x,y); 46 for(int i=20;i>=0;i--) 47 { 48 if(dep[fa[x][i]]>=dep[y])x=fa[x][i]; 49 if(x==y)return x; 50 } 51 for(int i=20;i>=0;i--) 52 { 53 if(fa[x][i]!=fa[y][i])x=fa[x][i],y=fa[y][i]; 54 } 55 return fa[x][0]; 56 } 57 void dfs2(int x,int f) 58 { 59 ord[++ncnt]=x; 60 fir[x]=ncnt; 61 for(int i=head[x];i;i=nxt[i]) 62 { 63 if(v[i]==f)continue; 64 dfs2(v[i],x); 65 } 66 ord[++ncnt]=x; 67 lst[x]=ncnt; 68 } 69 void work(int pos) 70 { 71 if(vis[pos])now-=!--cnt[val[pos]]; 72 else now+=!cnt[val[pos]]++; 73 vis[pos]^=1; 74 } 75 int main() 76 { 77 n=read();m=read(); 78 for(int i=1;i<=n;i++)val[i]=inp[i]=read(); 79 sort(inp+1,inp+n+1); 80 tot=unique(inp+1,inp+n+1)-inp-1; 81 for(int i=1;i<=n;i++)val[i]=lower_bound(inp+1,inp+tot+1,val[i])-inp; 82 for(int x,y,i=1;i<n;i++) 83 { 84 x=read();y=read(); 85 add(x,y);add(y,x); 86 } 87 dfs1(1,0);dfs2(1,0); 88 for(int i=1;i<=m;i++) 89 { 90 int ll=read(),rr=read(),LCA=lca(ll,rr); 91 if(fir[ll]>fir[rr])swap(ll,rr); 92 if(ll==LCA) 93 { 94 q[i].l=fir[ll]; 95 q[i].r=fir[rr]; 96 } 97 else 98 { 99 q[i].l=lst[ll]; 100 q[i].r=fir[rr]; 101 q[i].lca=LCA; 102 } 103 q[i].x=i; 104 } 105 siz=sqrt(ncnt); 106 sort(q+1,q+m+1,cmp); 107 for(int i=1;i<=m;i++) 108 { 109 int ll=q[i].l,rr=q[i].r,Lca=q[i].lca; 110 while(l<ll)work(ord[l++]); 111 while(l>ll)work(ord[--l]); 112 while(r<rr)work(ord[++r]); 113 while(r>rr)work(ord[r--]); 114 if(Lca)work(Lca); 115 ans[q[i].x]=now; 116 if(Lca)work(Lca); 117 } 118 for(int i=1;i<=m;i++)printf("%d ",ans[i]); 119 return 0; 120 }