1 树剖介绍
树链剖分是用来处理树上路径问题。(如路径和)
这里以P2590 [ZJOI2008]树的统计为模板来讲解。
因为有修改所以暴力是肯定不行的。这样我们就要请出我们今天的主角——树链剖分了。
树链剖分其实就是把一颗数剖成很多条链,给链上的节点重新编号使其编号连续。这样就可在这条链上用其它处理线段的数据结构(一般是线段树)处理了。
常用的有轻重链剖分。就是把树剖成多条重链和轻链。
对于每一个节点,他会有一个重儿子和若干个轻儿子。重儿子就延续当前的重链,轻儿子则作为新的一条重链的开始。重儿子是指子树大小最大的儿子,其他的都是轻儿子。
这样这棵树就被剖成了很多条链。显然每个节点都属于一条重链。
重儿子指的就是子树大小最大的儿子,轻儿子是其它儿子。
如图,红边是重链,蓝边是轻链,星星是重链开始
我们还需要重新编号,其实只要按照重儿子dfs的dfs序即可。
为了后面的查询,我们还得存储一个dep代表深度,fa代表在树上的父亲以及bl代表当前重链的开始。
我们来看具体的代码实现。
代码有两个dfs(具体看注释)。
void dfs1(int x) { sz[x] = 1;//sz是存储子树大小的 for (int i = head[x]; i; i = edge[i].pre) { int y = edge[i].to; if (y == fa[x]) continue; //预处理深度和父节点 dep[y] = dep[x] + 1; fa[y] = x; dfs1(y, x); sz[x] += sz[y];//统计子树大小 } }
void dfs2(int x, int chain) { int k = 0;//重儿子 dfn[x] = ++len; bl[x] = chain;//chain是树链开始 for (int i = head[x]; i; i = edge[i].pre) { int y = edge[i].to; if (dep[y] < dep[x]) continue; if (sz[y] > sz[k]) { k = y; } }//查找重儿子 if (k) dfs2(k, chain);//如果有重儿子则先递归重儿子 for (int i = head[x]; i; i = edge[i].pre) { int y = edge[i].to; if (y == fa[x] || y == k) continue; dfs2(y, y);//轻儿子是重链的开始 } }
那么如何查询呢?
其实也很简单
我们以查询“和”为例来介绍。
假设我们我们要查询 $u->v$ 的路径上的“和”。
我们首先判断它们在不在同一条重链上
如果在,直接线段树查询这一条链,返回即可。
如果不在,则需要其中一个节点 $u$ 统计从 $bl[u]->u$ 这条链上的答案,然后 $u$ 跳到 $fa[bl[u]]$。
记住这个 $u$ 必须是 $bl的dep$大的一个,否则有可能得不到正确答案。
比如上图,如果跳 $v$ 则会跳到 $root$ 就的不到答案而且会死循环。
代码实现
int Qsum(int u, int v) { int sum = 0; while (bl[u] != bl[v]) { if (dep[bl[u]] < dep[bl[v]]) swap(u, v); sum += ask_sum(1, pos[bl[u]], pos[u]);//ask是线段树查询区间和 u = fa[bl[u]]; } if (dep[u] > dep[v]) { swap(u, v); } sum += ask_sum(1, pos[u], pos[v]);//在一条链上也别忘统计 return sum; }
注意:线段树维护就是正常的区间维护(对dfn)。
时间复杂度分析:
在重链上线段树求答案是 $O(log{N})$ 的,而对于跳轻链,因为轻儿子的子树大小至多是一半,所以级别是 $O(log{N})$ 的。
完整代码:
#include <iostream> #include <cstdio> using namespace std; const int N = 30010, inf = 0x7f7f7f7f; struct seg_tree{ int maxi, val, l, r; }st[4 * N]; struct node{ int pre, to; }edge[2 * N]; int head[N], tot; int n; int a, b, dep[N], fa[N], bl[N], sz[N], w[N], pos[N]; int len, QQ; void dfs1(int x, int f) { sz[x] = 1; for (int i = head[x]; i; i = edge[i].pre) { int y = edge[i].to; if (y == f) continue; dep[y] = dep[x] + 1; fa[y] = x; dfs1(y, x); sz[x] += sz[y]; } } void dfs2(int x, int chain) { int k = 0; pos[x] = ++len; bl[x] = chain; for (int i = head[x]; i; i = edge[i].pre) { int y = edge[i].to; if (dep[y] < dep[x]) continue; if (sz[y] > sz[k]) { k = y; } } if (k) dfs2(k, chain); for (int i = head[x]; i; i = edge[i].pre) { int y = edge[i].to; if (dep[y] < dep[x] || y == k) continue; dfs2(y, y); } } void build(int x, int l, int r) { st[x].l = l, st[x].r = r; if (l == r) return; int mid = (l + r) >> 1; build(x << 1, l, mid); build(x << 1 | 1, mid + 1, r); } void change(int x, int p, int v) { int l = st[x].l, r = st[x].r; if (l == r) { st[x].val = st[x].maxi = v; return; } int mid = (l + r) >> 1; if (p <= mid) change(x << 1, p, v); else change(x << 1 | 1, p, v); st[x].val = st[x << 1].val + st[x << 1 | 1].val; st[x].maxi = max(st[x << 1].maxi, st[x << 1 | 1].maxi); } int ask_max(int x, int L, int R) { int l = st[x].l, r = st[x].r; if (L <= l && r <= R) { return st[x].maxi; } if (l > R || r < L) return -inf; return max(ask_max(x << 1, L, R), ask_max(x << 1 | 1, L, R)); } int ask_sum(int x, int L, int R) { int l = st[x].l, r = st[x].r; if (L <= l && r <= R) { return st[x].val; } if (l > R || r < L) return 0; return ask_sum(x << 1, L, R) + ask_sum(x << 1 | 1, L, R); } int Qsum(int u, int v) { int sum = 0; while (bl[u] != bl[v]) { if (dep[bl[u]] < dep[bl[v]]) swap(u, v); sum += ask_sum(1, pos[bl[u]], pos[u]); u = fa[bl[u]]; } if (dep[u] > dep[v]) { swap(u, v); } sum += ask_sum(1, pos[u], pos[v]); return sum; } int Qmax(int u, int v) { int maxi = -inf; while (bl[u] != bl[v]) { if (dep[bl[u]] < dep[bl[v]]) swap(u, v); maxi = max(maxi, ask_max(1, pos[bl[u]], pos[u])); u = fa[bl[u]]; } if (dep[u] > dep[v]) { swap(u, v); } maxi = max(maxi, ask_max(1, pos[u], pos[v])); return maxi; } void add(int u, int v) { edge[++tot] = node{head[u], v}; head[u] = tot; } int main() { cin >> n; for (int i = 1, a, b; i < n; i++) { cin >> a >> b; add(a, b); add(b, a); } dfs1(1, 0); dfs2(1, 1); build(1, 1, len); for (int i = 1; i <= n; i++) { cin >> w[i]; change(1, pos[i], w[i]); } cin >> QQ; while (QQ--) { string opt; int u, v; cin >> opt >> u >> v; if (opt == "QMAX") { cout << Qmax(u, v) << " "; } else if (opt == "QSUM") { cout << Qsum(u, v) << " "; } else { change(1, pos[u], v); } } return 0; }
练习:
因为时间有限,所以部分习题不配有题解。如有需要,请在下方留言谢谢。
2 子树问题
我们来看这道P3384 【模板】轻重链剖分,它还叫我们输出子树和,这可怎么办呢?
我们发现按照上述方式给节点编号那么一棵子树的编号也是连续的(即一段区间),那么我们直接区间查询修改即可。
练习:
3 LCA
理解了上面讲的东西那这个就很好想了。
一直跳到两点在同一重链上,然后返回较浅的点,$O(log{N})$。
int LCA(int u, int v) { while (bl[u] != bl[v]) { if (dep[bl[u]] < dep[bl[v]]) swap(u, v); u = fa[bl[u]]; } return dep[u] < dep[v] ? u : v; }
树剖求LCA的常数较小,可用来卡常。
习题: