2243: [SDOI2011]染色
Time Limit: 20 Sec Memory Limit: 512 MBSubmit: 9031 Solved: 3388
[Submit][Status][Discuss]
Description
给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),
如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。
Input
第一行包含2个整数n和m,分别表示节点数和操作数;
第二行包含n个正整数表示n个节点的初始颜色
下面 行每行包含两个整数x和y,表示x和y之间有一条无向边。
下面 行每行描述一个操作:
“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;
“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。
Output
对于每个询问操作,输出一行答案。
Sample Input
6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
Sample Output
3
1
2
1
2
HINT
数N<=10^5,操作数M<=10^5,所有的颜色C为整数且在[0, 10^9]之间。
Source
分析:一眼树链剖分+线段树题,主要是边界和细节地方的处理.
线段树记录一下左端点和右端点的颜色,每次合并的时候要看两个端点的颜色是否一致.如果一致那么答案要-1.树链剖分求答案的时候要记录一下当前链头的颜色和链尾的颜色,判断是否相等.在线段树query的时候也要判断一下两端是否有相同颜色.需要注意的是要添加两个bool变量,记录是否跨过两个区间.
#include <cstdio> #include <cstring> #include <iostream> #include <algorithm> using namespace std; const int maxn = 100010; int sum[maxn << 2],lc[maxn <<2],rc[maxn << 2],tag[maxn << 2],dep[maxn],son[maxn],top[maxn],sizee[maxn],fa[maxn],id[maxn],idx[maxn],cnt; int head[maxn],to[maxn * 2],nextt[maxn * 2],tot = 1,v[maxn]; int n,m,lcol,rcol,ans1,ans2,ans,L,R; char s[10]; void add(int x,int y) { to[tot] = y; nextt[tot] = head[x]; head[x] = tot++; } void dfs(int u,int d,int from) { fa[u] = from; dep[u] = d; sizee[u] = 1; for (int i = head[u]; i; i = nextt[i]) { int v = to[i]; if (v == from) continue; dfs(v,d + 1,u); sizee[u] += sizee[v]; if(sizee[v] >= sizee[son[u]]) son[u] = v; } } void dfs2(int u,int topp) { top[u] = topp; id[u] = ++cnt; idx[cnt] = u; if (son[u]) dfs2(son[u],topp); for (int i = head[u]; i; i = nextt[i]) { int v = to[i]; if(v == fa[u] || v == son[u]) continue; dfs2(v,v); } } void pushup(int o) { lc[o] = lc[o * 2]; rc[o] = rc[o * 2 + 1]; sum[o] = sum[o * 2] + sum[o * 2 + 1]; if (rc[o * 2] == lc[o * 2 + 1]) sum[o]--; }void pushdown(int o) { if (tag[o]) { tag[o * 2] = tag[o * 2 + 1] = 1; tag[o] = 0; lc[o * 2] = rc[o * 2] = lc[o * 2 + 1] = rc[o * 2 + 1] = lc[o]; sum[o * 2] = sum[o * 2 + 1] = 1; } } void update(int o,int l,int r,int x,int y,int v) { if (x <= l && r <= y) { tag[o] = 1; lc[o] = rc[o] = v; sum[o] = 1; return; } pushdown(o); int mid = (l + r) >> 1; if (x <= mid) update(o * 2,l,mid,x,y,v); if (y > mid) update(o * 2 + 1,mid + 1,r,x,y,v); pushup(o); } int query(int o,int l,int r,int x,int y) { if (l == L) lcol = lc[o]; if (r == R) rcol = rc[o]; if (x <= l && r <= y) return sum[o]; pushdown(o); int mid = (l + r) >> 1,res = 0; bool flag1 = false,flag2 = false; if (x <= mid) { res += query(o * 2,l,mid,x,y); flag1 = true; } if (y > mid) { res += query(o * 2 + 1,mid + 1,r,x,y); flag2 = true; } if (flag1 && flag2 && rc[o * 2] == lc[o * 2 + 1]) res--; return res; } void solve2(int x,int y,int z) { if (dep[x] < dep[y]) swap(x,y); while (top[x] != top[y]) { if (dep[top[x]] < dep[top[y]]) swap(x,y); update(1,1,n,id[top[x]],id[x],z); x = fa[top[x]]; } if (dep[x] < dep[y]) swap(x,y); update(1,1,n,id[y],id[x],z); } int solve1(int x,int y) { ans1 = ans2 = -1; ans = 0; while (top[x] != top[y]) { if (dep[top[x]] < dep[top[y]]) { swap(x,y); swap(ans1,ans2); } L = id[top[x]],R = id[x]; ans += query(1,1,n,id[top[x]],id[x]); if (rcol == ans1) ans--; ans1 = lcol; x = fa[top[x]]; } if (dep[x] < dep[y]) swap(x,y),swap(ans1,ans2); L = id[y],R = id[x]; ans += query(1,1,n,id[y],id[x]); if (rcol == ans1) ans--; if (lcol == ans2) ans--; return ans; } int main() { scanf("%d%d",&n,&m); for (int i = 1; i <= n; i++) scanf("%d",&v[i]); for (int i = 1; i < n; i++) { int u,v; scanf("%d%d",&u,&v); add(u,v); add(v,u); } dfs(1,1,0); dfs2(1,1); for (int i = 1; i <= n; i++) update(1,1,n,id[i],id[i],v[i]); for (int i = 1; i <= m; i++) { scanf("%s",s); int a,b,c; if(s[0] == 'Q') { scanf("%d%d",&a,&b); printf("%d ",solve1(a,b)); } else { scanf("%d%d%d",&a,&b,&c); solve2(a,b,c); } } return 0; }