https://www.lydsy.com/JudgeOnline/problem.php?id=2243
刚学的树链剖分。线段树的每个结点维护区间内的颜色段数，以及区间左、右端点的颜色。合并两个相邻区间时（建树/修改时的 Pushup，以及查询时拼接左右两段），要判断左区间右端点与右区间左端点的颜色是否相同，若相同则颜色段数减 1。
在沿重链向上跳、累加不同链上的答案时，也要判断链顶 top[u] 与其父结点 fa[top[u]] 的颜色是否相同，相同则 ans--。注意这里的颜色必须通过线段树单点查询得到当前值，而不能直接读取初始颜色数组 color[]，因为染色操作只修改线段树。
#include <map>
#include <set>
#include <ctime>
#include <cmath>
#include <queue>
#include <stack>
#include <vector>
#include <string>
#include <cctype>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <sstream>
#include <iostream>
#include <algorithm>
#include <functional>
using namespace std;

// Reads the next non-negative decimal integer from stdin.
// Any non-digit characters before the number are skipped (so a leading '-'
// is ignored, not honoured). Returns the value accumulated so far if EOF is
// reached; `c` is an int so EOF is detected instead of looping forever, and
// so isdigit() never receives a negative char value.
inline int read() {
    int now = 0;
    int c = getchar();
    while (c != EOF && !isdigit(c)) c = getchar();
    while (c != EOF && isdigit(c)) {
        now = now * 10 + (c - '0');
        c = getchar();
    }
    return now;
}

const int maxn = 1e5 + 10;
// Colors are read by read(), which only yields non-negative values, and 0 is
// a legal color. So "no pending assignment" must be a negative sentinel,
// not 0.
const int NO_LAZY = -1;

int N, M;

// Adjacency list (each undirected edge stored twice).
int head[maxn], tot;
struct Edge {
    int to, next;
} edge[maxn * 2];

// Resets the adjacency heads for vertices 1..N.
void init() {
    for (int i = 1; i <= N; i++) head[i] = -1;
    tot = 0;
}

// Appends the directed edge u -> v.
void add(int u, int v) {
    edge[tot].to = v;
    edge[tot].next = head[u];
    head[u] = tot++;
}

// Tree information filled by the two DFS passes.
int dep[maxn], hson[maxn], fa[maxn], Size[maxn], color[maxn], Index[maxn], top[maxn];
int nw[maxn];  // nw[pos] = initial color of the vertex placed at DFS position pos

// First pass: depth, parent, subtree size and heavy child (0 = none).
void dfs1(int t, int la) {
    Size[t] = 1;
    hson[t] = 0;
    int MAX = 0;
    for (int i = head[t]; ~i; i = edge[i].next) {
        int v = edge[i].to;
        if (v == la) continue;
        dep[v] = dep[t] + 1;
        fa[v] = t;
        dfs1(v, t);
        Size[t] += Size[v];
        if (Size[v] > MAX) {
            MAX = Size[v];
            hson[t] = v;
        }
    }
}

int cnt = 0;
// Second pass: assigns DFS positions, visiting the heavy child first so that
// every heavy chain occupies a contiguous range; `la` is the chain's top.
void dfs2(int t, int la) {
    top[t] = la;
    Index[t] = ++cnt;
    nw[cnt] = color[t];
    if (hson[t]) dfs2(hson[t], la);
    for (int i = head[t]; ~i; i = edge[i].next) {
        int v = edge[i].to;
        if (v == hson[t] || v == fa[t]) continue;
        dfs2(v, v);
    }
}

// Segment tree over DFS positions.
struct Tree {
    int l, r;
    int sum;     // number of maximal same-color segments in [l, r]
    int lazy;    // pending "assign whole range to this color", or NO_LAZY
    int lc, rc;  // colors at positions l and r
} tree[maxn << 2];

// Recomputes node t from its children; if the boundary colors match, the two
// halves' segments merge into one.
void Pushup(int t) {
    tree[t].sum = tree[t << 1].sum + tree[t << 1 | 1].sum;
    tree[t].lc = tree[t << 1].lc;
    tree[t].rc = tree[t << 1 | 1].rc;
    if (tree[t << 1].rc == tree[t << 1 | 1].lc) tree[t].sum--;
}

void Build(int t, int l, int r) {
    tree[t].lazy = NO_LAZY;
    tree[t].l = l;
    tree[t].r = r;
    if (l == r) {
        tree[t].sum = 1;
        tree[t].lc = tree[t].rc = nw[l];
        return;
    }
    int m = (l + r) >> 1;
    Build(t << 1, l, m);
    Build(t << 1 | 1, m + 1, r);
    Pushup(t);
}

// Colors the whole range of node t with c and records it as pending.
void Apply(int t, int c) {
    tree[t].lazy = c;
    tree[t].sum = 1;
    tree[t].lc = tree[t].rc = c;
}

void Pushdown(int t) {
    if (tree[t].lazy != NO_LAZY) {
        Apply(t << 1, tree[t].lazy);
        Apply(t << 1 | 1, tree[t].lazy);
        tree[t].lazy = NO_LAZY;
    }
}

// Assigns color c to positions [l, r] (must lie inside node t's range).
void update(int t, int l, int r, int c) {
    if (l <= tree[t].l && tree[t].r <= r) {
        Apply(t, c);
        return;
    }
    Pushdown(t);
    int m = (tree[t].l + tree[t].r) >> 1;
    if (r <= m) {
        update(t << 1, l, r, c);
    } else if (l > m) {
        update(t << 1 | 1, l, r, c);
    } else {
        update(t << 1, l, m, c);
        update(t << 1 | 1, m + 1, r, c);
    }
    Pushup(t);
}

// Number of color segments in positions [l, r].
int query(int t, int l, int r) {
    if (l <= tree[t].l && tree[t].r <= r) return tree[t].sum;
    Pushdown(t);
    int m = (tree[t].l + tree[t].r) >> 1;
    if (r <= m) return query(t << 1, l, r);
    if (l > m) return query(t << 1 | 1, l, r);
    int ans = query(t << 1, l, m) + query(t << 1 | 1, m + 1, r);
    // The left part ends at position m and the right part starts at m + 1;
    // after Pushdown these are exactly the children's rc / lc.
    if (tree[t << 1].rc == tree[t << 1 | 1].lc) ans--;
    return ans;
}

// Current color at position p.
// A node with a single segment is uniformly colored. This covers leaves and
// any node holding a pending assignment (Apply sets sum = 1), so its lc is
// the answer. Otherwise no assignment is pending and the children are up to
// date, so we can descend without pushing down.
int query(int t, int p) {
    while (tree[t].sum != 1) {
        int m = (tree[t].l + tree[t].r) >> 1;
        t = (p <= m) ? (t << 1) : (t << 1 | 1);
    }
    return tree[t].lc;
}

// Heavy-light decomposition: color every vertex on the path u..v with c.
void update(int u, int v, int c) {
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) swap(u, v);
        update(1, Index[top[u]], Index[u], c);
        u = fa[top[u]];
    }
    if (dep[u] > dep[v]) swap(u, v);
    update(1, Index[u], Index[v], c);
}

// Number of color segments along the path u..v.
int Query(int u, int v) {
    int ans = 0;
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) swap(u, v);
        ans += query(1, Index[top[u]], Index[u]);
        // top[u] and fa[top[u]] are adjacent on the path but lie in different
        // chains, so a shared color there was counted twice.
        if (query(1, Index[fa[top[u]]]) == query(1, Index[top[u]])) ans--;
        u = fa[top[u]];
    }
    if (dep[u] > dep[v]) swap(u, v);
    ans += query(1, Index[u], Index[v]);
    return ans;
}

// Input: N M, then N vertex colors, N-1 edges, then M operations
// "C a b c" (color path a..b with c) or "Q a b" (print segment count on a
// separate line).
int main() {
    if (scanf("%d%d", &N, &M) != 2 || N <= 0) return 0;
    init();
    for (int i = 1; i <= N; i++) color[i] = read();
    for (int i = 1; i < N; i++) {
        int u = read(), v = read();
        add(u, v);
        add(v, u);
    }
    int root = N;
    dep[root] = 0;
    dfs1(root, -1);
    cnt = 0;
    dfs2(root, root);
    Build(1, 1, N);
    while (M--) {
        char op[3];
        if (scanf("%2s", op) != 1) break;
        int a = read(), b = read();
        if (op[0] == 'C') {
            int c = read();
            update(a, b, c);
        } else {
            printf("%d\n", Query(a, b));
        }
    }
    return 0;
}