【题目描述】
给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。
【输入格式】
第一行包含2个整数n和m,分别表示节点数和操作数;
第二行包含n个正整数表示n个节点的初始颜色
下面n-1行每行包含两个整数x和y,表示x和y之间有一条无向边。
下面m行每行描述一个操作:
“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;
“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。
【输出格式】
对于每个询问操作,输出一行答案。
题解
看到这种树上操作路径的题大概率就是树链剖分了
先考虑一下 如果是一条链要怎么做
显然 可以用线段树维护 每个区间记录三个值:此区间内的颜色段数量,区间最左边点的颜色,区间最右边点的颜色
那么 一个区间的颜色段数量=左区间颜色段数量+右区间颜色段数量-(左区间最右颜色==右区间最左颜色);最左边点颜色=左区间最左点颜色 最右边点颜色=右区间最右边点颜色
查询也是差不多 把左儿子返回的结果(这里可以返回一个结构体)和右儿子返回的结果按上一行的方法合并一下即可
然而树剖查询一条路径时是有最多log次的线段树区间查询 而不是直接查询一个区间 怎么合并这些查询答案统计出最终答案呢
其实只需要记录一下上一次查询区间的左端点颜色 如果和这一次查询区间的右端点颜色相同 答案-1即可
具体可见代码 统计答案这段还是有点抽搐
ps: 这题线段树区间修改要打懒标记
【代码】
#include <bits/stdc++.h>
using namespace std;
int n, m, a[100005], b1, b2, b3, lst[10];
int head[100005], pre[200005], to[200005], sz;
int fa[100005], d[100005], st[100005], dfn[100005], son[100005], pos[100005], siz[100005], tme;
char tp[5];
struct segtree{
int l, r, lcol, rcol, cnt, tag;
segtree() {
lcol = rcol = tag = -1; l = r = cnt = 0;
}
} tr[400005];
inline void addedge(int u, int v) {
pre[++sz] = head[u]; head[u] = sz; to[sz] = v;
}
namespace Segtree{
inline void pushup(int ind) {
tr[ind].cnt = tr[ind<<1].cnt + tr[ind<<1|1].cnt - (tr[ind<<1].rcol == tr[ind<<1|1].lcol);
tr[ind].lcol = tr[ind<<1].lcol; tr[ind].rcol = tr[ind<<1|1].rcol;
}
inline void pushdown(int ind) {
if (tr[ind].tag == -1) return;
tr[ind<<1].cnt = tr[ind<<1|1].cnt = 1;
tr[ind<<1].lcol = tr[ind<<1|1].lcol = tr[ind<<1].rcol = tr[ind<<1|1].rcol = tr[ind].tag;
tr[ind<<1].tag = tr[ind<<1|1].tag = tr[ind].tag; tr[ind].tag = -1;
}
void build(int ind, int l, int r) {
tr[ind] = segtree();
tr[ind].l = l; tr[ind].r = r; tr[ind].tag = -1;
if (l == r) {
tr[ind].cnt = 1; tr[ind].lcol = tr[ind].rcol = a[pos[l]];
return;
}
int mid = (l + r) >> 1;
build(ind<<1, l, mid); build(ind<<1|1, mid+1, r);
pushup(ind);
}
void update(int ind, int x, int y, int v) {
int l = tr[ind].l, r = tr[ind].r;
if (x <= l && r <= y) {
tr[ind].cnt = 1; tr[ind].lcol = tr[ind].rcol = v; tr[ind].tag = v; return;
}
int mid = (l + r) >> 1;
pushdown(ind);
if (x <= mid) update(ind<<1, x, y, v);
if (mid < y) update(ind<<1|1, x, y, v);
pushup(ind);
}
inline segtree merge(segtree a, segtree b) {
segtree ret = segtree();
ret.cnt = a.cnt + b.cnt - (a.rcol == b.lcol);
ret.lcol = a.lcol; ret.rcol = b.rcol;
return ret;
}
segtree query(int ind, int x, int y) {
int l = tr[ind].l, r = tr[ind].r;
if (x <= l && r <= y) {
return tr[ind];
}
int mid = (l + r) >> 1;
pushdown(ind);
segtree ret1 = segtree(), ret2 = segtree();
if (x <= mid) ret1 = query(ind<<1, x, y);
if (mid < y) ret2 = query(ind<<1|1, x, y);
if (!ret1.cnt) return ret2;
else if (!ret2.cnt) return ret1;
else return merge(ret1, ret2);
}
}
using namespace Segtree;
namespace treechains{
void dfs1(int x, int f) {
siz[x] = 1;
for (int i = head[x]; i; i = pre[i]) {
int y = to[i];
if (y == f) continue;
d[y] = d[x] + 1; fa[y] = x;
dfs1(y, x); siz[x] += siz[y];
if (siz[y] > siz[son[x]]) son[x] = y;
}
}
void dfs2(int x, int start) {
dfn[x] = ++tme; pos[tme] = x; st[x] = start;
if (son[x]) dfs2(son[x], start);
for (int i = head[x]; i; i = pre[i]) {
int y = to[i];
if (y != fa[x] && y != son[x]) dfs2(y, y);
}
}
void change(int x, int y, int z) {
while (st[x] != st[y]) {
if (d[st[x]] < d[st[y]]) swap(x, y);
update(1, dfn[st[x]], dfn[x], z);
x = fa[st[x]];
}
if (dfn[x] > dfn[y]) swap(x, y);
update(1, dfn[x], dfn[y], z);
}
int ask(int x, int y) {
int ret = 0, o = 0; lst[0] = lst[1] = -1;
//lst[0]: x~lca这条链上上一次查询的左端点颜色; lst[1]: y~lca这条链上上一次查询的左端点颜色
segtree tmp = segtree();
while (st[x] != st[y]) {
if (d[st[x]] < d[st[y]]) swap(x, y), o ^= 1;
tmp = query(1, dfn[st[x]], dfn[x]);
ret += tmp.cnt - (tmp.rcol == lst[o]); lst[o] = tmp.lcol;
x = fa[st[x]];
}
if (dfn[x] > dfn[y]) swap(x, y), o ^= 1;
tmp = query(1, dfn[x], dfn[y]);
if (!o) {
ret += tmp.cnt - (lst[0] == tmp.lcol) - (lst[1] == tmp.rcol);
} else {
ret += tmp.cnt - (lst[1] == tmp.lcol) - (lst[0] == tmp.rcol);
}
return ret;
}
}
using namespace treechains;
int main() {
scanf("%d %d", &n, &m);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
for (int i = 1; i < n; i++) {
int x, y; scanf("%d %d", &x, &y);
addedge(x, y); addedge(y, x);
}
dfs1(1, 0); dfs2(1, 1);
build(1, 1, n);
for (int i = 1; i <= m; i++) {
scanf("%s", tp);
if (tp[0] == 'C') {
scanf("%d %d %d", &b1, &b2, &b3);
change(b1, b2, b3);
} else {
scanf("%d %d", &b1, &b2);
printf("%d
", ask(b1, b2));
}
}
return 0;
}