zoukankan      html  css  js  c++  java
  • 【数据结构】树链剖分详细讲解

     “在一棵树上进行路径的修改、求极值、求和”乍一看只要线段树就能轻松解决,实际上,仅凭线段树是不能搞定它的。我们需要用到一种貌似高级的复杂算法——树链剖分。

    树链剖分是把一棵树分割成若干条链,以便于维护信息的一种方法,其中最常用的是重链剖分(Heavy Path Decomposition,重路径分解),所以一般提到树链剖分或树剖都是指重链剖分。除此之外还有长链剖分和实链剖分等,本文暂不介绍。

    首先我们需要明确概念:

    • 重儿子:父亲节点的所有儿子中子树结点数目最多(size最大)的结点;
    • 轻儿子:父亲节点中除了重儿子以外的儿子;
    • 重边:父亲结点和重儿子连成的边;
    • 轻边:父亲节点和轻儿子连成的边;
    • 重链:由多条重边连接而成的路径;
    • 轻链:由多条轻边连接而成的路径;

    我们定义树上一个节点的子节点中子树最大的一个为它的重子节点,其余的为轻子节点。一个节点连向其重子节点的边称为重边,连向轻子节点的边则为轻边。如果把根节点看作轻的,那么从每个轻节点出发,不断向下走重边,都对应了一条链,于是我们把树剖分成了 (l) 条链,其中 (l) 是轻节点的数量。

    最近因为画图工具出了点问题,所以转载了Pecco学长的示意图(下面求LCA的方法的部分内容也来自Pecco学长)

    剖分后的树(重链)有如下性质:

    1. 对于节点数为 (n) 的树,从任意节点向上走到根节点,经过的轻边数量不会超过 (log n)

      这是因为当我们向下经过一条 轻边 时,所在子树的大小至少会除以二。所以说,对于树上的任意一条路径,把它拆分成从 (lca) 分别向两边往下走,分别最多走 (O(log n)) 次,树上的每条路径都可以被拆分成不超过 (O(log n)) 条重链。

    2. 树上每个节点都属于且仅属于一条重链

    重链开头的结点不一定是重子节点(因为重边是对于每一个结点都有定义的)。所有的重链将整棵树 完全剖分

    尽管树链部分看起来很难实现(的确有点繁琐),但我们可以用两个 DFS 来实现树链(树剖)。

    相关伪码(来自 OI wiki)

    第一个 DFS 记录每个结点的父节点(father)、深度(deep)、子树大小(size)、重子节点(hson)。

    [egin{array}{l} ext{TREE-BUILD }(u,dep) \ egin{array}{ll} 1 & u.hsongets 0 \ 2 & u.hson.sizegets 0 \ 3 & u.deepgets dep \ 4 & u.sizegets 1 \ 5 & extbf{for } ext{each }u ext{'s son }v \ 6 & qquad u.sizegets u.size + ext{TREE-BUILD }(v,dep+1) \ 7 & qquad v.fathergets u \ 8 & qquad extbf{if }v.size> u.hson.size \ 9 & qquad qquad u.hsongets v \ 10 & extbf{return } u.size end{array} end{array} ]

    第二个 DFS 记录所在链的链顶(top,应初始化为结点本身)、重边优先遍历时的 DFS 序(dfn)、DFS 序对应的节点编号(rank)。

    [egin{array}{l} ext{TREE-DECOMPOSITION }(u,top) \ egin{array}{ll} 1 & u.topgets top \ 2 & totgets tot+1\ 3 & u.dfngets tot \ 4 & rank(tot)gets u \ 5 & extbf{if }u.hson ext{ is not }0 \ 6 & qquad ext{TREE-DECOMPOSITION }(u.hson,top) \ 7 & qquad extbf{for } ext{each }u ext{'s son }v \ 8 & qquad qquad extbf{if }v ext{ is not }u.hson \ 9 & qquad qquad qquad ext{TREE-DECOMPOSITION }(v,v) end{array} end{array} ]

    以下为代码实现。

    我们先给出一些定义:

    • (fa(x)) 表示节点 (x) 在树上的父亲(也就是父节点)。
    • (dep(x)) 表示节点 (x) 在树上的深度。
    • (siz(x)) 表示节点 (x) 的子树的节点个数。
    • (son(x)) 表示节点 (x)重儿子
    • (top(x)) 表示节点 (x) 所在 重链 的顶部节点(深度最小)。
    • (dfn(x)) 表示节点 (x)DFS 序 ,也是其在线段树中的编号。
    • (rnk(x)) 表示 DFS 序所对应的节点编号,有 (rnk(dfn(x))=x)

    我们进行两遍 DFS 预处理出这些值,其中第一次 DFS 求出 (fa(x)) , (dep(x)) , (siz(x)) , (son(x)) ,第二次 DFS 求出 (top(x)) , (dfn(x)) , (rnk(x))

    // 当然树链写法不止一种,这个是我学习Oi wiki上知识点记录的模板代码
    void dfs1(int o) {
        son[o] = -1, siz[o] = 1;
        for (int j = h[o]; j; j = nxt[j])
            if (!dep[p[j]]) {
                dep[p[j]] = dep[o] + 1;
                fa[p[j]] = o;
                dfs1(p[j]);
                siz[o] += siz[p[j]];
                if (son[o] == -1 || siz[p[j]] > siz[son[o]])
                    son[o] = p[j];
            }
    }
    void dfs2(int o, int t) {
        top[o] = t;
        dfn[o] = ++cnt;
        rnk[cnt] = o;
        if (son[o] == -1)
            return;
        dfs2(son[o], t);  // 优先对重儿子进行 DFS,可以保证同一条重链上的点 DFS 序连续
        for (int j = h[o]; j; j = nxt[j])
            if (p[j] != son[o] && p[j] != fa[o])
                dfs2(p[j], p[j]);
    }
    
    // 写法2:来自Peocco学长,代码仅作学习使用
    void dfs1(int p, int d = 1){
        int Siz = 1,ma = 0;
        dep[p] = d;
        for(auto q : edges[p]){ // for循环写法和auto是C++11标准,竞赛可用
            dfs1(q,d + 1);
            fa[q] = p;
            Siz += sz[q];
            if(sz[q] > ma)
                hson[p] = q, ma = sz[q];// hson = 重儿子
        }
        sz[p] = Siz; 
    }
    // 需要先把根节点的top初始化为自身
    void dfs2(int p){
        for(auto q : edges[p]){
            if(!top[q]){
                if(q == hson[p])
                    top[q] = top[p];
               	else 
                    top[q] = q;
                dfs2(q);
            }
        }
    }
    

    以上这样便完成了剖分。

    学习到这里想想开头的那句话:

     “在一棵树上进行路径的修改、求极值、求和”乍一看只要线段树就能轻松解决,实际上,仅凭线段树是不能搞定它的。我们需要用到一种貌似高级的复杂算法——树链剖分。

    如果不能一下想不到线段树解决不了的问题的话不如看看这道题 ↓

    Hdu 3966 Aragorn's Story

    题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=3966

    题意:给一棵树,并给定各个点权的值,然后有3种操作:
      I C1 C2 K: 把C1与C2的路径上的所有点权值加上K
      D C1 C2 K:把C1与C2的路径上的所有点权值减去K
      Q C:查询节点编号为C的权值

      分析:典型的树链剖分题目,先进行剖分,然后用线段树去维护即可

    // Author : RioTian
    // Time : 20/11/30
    #include <bits/stdc++.h>
    using namespace std;
    
    #define lson l, m, rt << 1
    #define rson m + 1, r, rt << 1 | 1
    
    typedef long long ll;
    typedef int lld;
    
    stack<int> ss;
    const int maxn = 2e5 + 10;
    const int inf = ~0u >> 2;  // 1073741823
    int M[maxn << 2];
    int add[maxn << 2];
    
    struct node {
        int s, t, w, next;
    } edges[maxn << 1];
    
    int E, n;
    int Size[maxn], fa[maxn], heavy[maxn], head[maxn], vis[maxn];
    int dep[maxn], rev[maxn], num[maxn], cost[maxn], w[maxn];
    int Seg_size;
    
    int find(int x) {
        return fa[x] == x ? x : fa[x] = find(fa[x]);
    }
    
    void add_edge(int s, int t, int w) {
        edges[E].w = w;
        edges[E].s = s;
        edges[E].t = t;
        edges[E].next = head[s];
        head[s] = E++;
    }
    
    void dfs(int u, int f) {  //起点,父节点
        int mx = -1, e = -1;
        Size[u] = 1;
        for (int i = head[u]; i != -1; i = edges[i].next) {
            int v = edges[i].t;
            if (v == f)
                continue;
            edges[i].w = edges[i ^ 1].w = w[v];
            dep[v] = dep[u] + 1;
            rev[v] = i ^ 1;
            dfs(v, u);
            Size[u] += Size[v];
            if (Size[v] > mx)
                mx = Size[v], e = i;
        }
        heavy[u] = e;
        if (e != -1)
            fa[edges[e].t] = u;
    }
    
    inline void pushup(int rt) {
        M[rt] = M[rt << 1] + M[rt << 1 | 1];
    }
    
    void pushdown(int rt, int m) {
        if (add[rt]) {
            add[rt << 1] += add[rt];
            add[rt << 1 | 1] += add[rt];
            M[rt << 1] += add[rt] * (m - (m >> 1));
            M[rt << 1 | 1] += add[rt] * (m >> 1);
            add[rt] = 0;
        }
    }
    
    void built(int l, int r, int rt) {
        M[rt] = add[rt] = 0;
        if (l == r)
            return;
        int m = (r + l) >> 1;
        built(lson), built(rson);
    }
    
    void update(int L, int R, int val, int l, int r, int rt) {
        if (L <= l && r <= R) {
            M[rt] += val;
            add[rt] += val;
            return;
        }
        pushdown(rt, r - l + 1);
        int m = (l + r) >> 1;
        if (L <= m)
            update(L, R, val, lson);
        if (R > m)
            update(L, R, val, rson);
        pushup(rt);
    }
    
    lld query(int L, int R, int l, int r, int rt) {
        if (L <= l && r <= R)
            return M[rt];
        pushdown(rt, r - l + 1);
        int m = (l + r) >> 1;
        lld ret = 0;
        if (L <= m)
            ret += query(L, R, lson);
        if (R > m)
            ret += query(L, R, rson);
        return ret;
    }
    
    void prepare() {
        int i;
        built(1, n, 1);
        memset(num, -1, sizeof(num));
        dep[0] = 0;
        Seg_size = 0;
        for (i = 0; i < n; i++)
            fa[i] = i;
        dfs(0, 0);
        for (i = 0; i < n; i++) {
            if (heavy[i] == -1) {
                int pos = i;
                while (pos && edges[heavy[edges[rev[pos]].t]].t == pos) {
                    int t = rev[pos];
                    num[t] = num[t ^ 1] = ++Seg_size;
                    // printf("pos=%d  val=%d t=%d
    ", Seg_size, edge[t].w, t);
                    update(Seg_size, Seg_size, edges[t].w, 1, n, 1);
                    pos = edges[t].t;
                }
            }
        }
    }
    
    int lca(int u, int v) {
        while (1) {
            int a = find(u), b = find(v);
            if (a == b)
                return dep[u] < dep[v] ? u : v;  // a,b在同一条重链上
            else if (dep[a] >= dep[b])
                u = edges[rev[a]].t;
            else
                v = edges[rev[b]].t;
        }
    }
    
    void CH(int u, int lca, int val) {
        while (u != lca) {
            int r = rev[u];  // printf("r=%d
    ",r);
            if (num[r] == -1)
                edges[r].w += val, u = edges[r].t;
            else {
                int p = fa[u];
                if (dep[p] < dep[lca])
                    p = lca;
                int l = num[r];
                r = num[heavy[p]];
                update(l, r, val, 1, n, 1);
                u = p;
            }
        }
    }
    
    void change(int u, int v, int val) {
        int p = lca(u, v);
        // printf("p=%d
    ",p);
        CH(u, p, val);
        CH(v, p, val);
        if (p) {
            int r = rev[p];
            if (num[r] == -1) {
                edges[r ^ 1].w += val;  //在此处发现了我代码的重大bug
                edges[r].w += val;
            } else
                update(num[r], num[r], val, 1, n, 1);
        }  //根节点,特判
        else
            w[p] += val;
    }
    
    lld solve(int u) {
        if (!u)
            return w[u];  //根节点,特判
        else {
            int r = rev[u];
            if (num[r] == -1)
                return edges[r].w;
            else
                return query(num[r], num[r], 1, n, 1);
        }
    }
    
    int main() {
        // freopen("in.txt", "r", stdin);
        ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
        int t, i, a, b, c, m, ca = 1, p;
        while (cin >> n >> m >> p) {
            memset(head, -1, sizeof(head));
            E = 0;
            for (int i = 0; i < n; ++i)
                cin >> w[i];
            for (int i = 0; i < m; ++i) {
                cin >> a >> b;
                a--, b--;
                add_edge(a, b, 0), add_edge(b, a, 0);
            }
            prepare();  // 预处理
            string op;
            while (p--) {
                cin >> op;
                if (op[0] == 'I') {  //区间添加
                    cin >> a >> b >> c;
                    a--, b--;
                    change(a, b, c);
                } else if (op[0] == 'D') {  //区间减少
                    cin >> a >> b >> c;
                    a--, b--;
                    change(a, b, -c);
                } else {  //查询
                    cin >> a;
                    a--;
                    cout << solve(a) << endl;
                }
            }
        }
        return 0;
    }
    

    由于数据很大,建议使用快读,而不是像我一样用 cin(差了近500ms了)

    折叠代码是千千dalao的解法:

    Code
    //千千dalao解法
    #include<bits/stdc++.h>
    using namespace std;
    typedef long long LL;
    const int maxn = 50010;
    struct Edge {
        int to;
        int next;
    } edge[maxn << 1];
    int head[maxn], tot;  //链式前向星存储
    int top[maxn];        // v所在重链的顶端节点
    int fa[maxn];         //父亲节点
    int deep[maxn];       //节点深度
    int num[maxn];        //以v为根的子树节点数
    int p[maxn];          // v与其父亲节点的连边在线段树中的位置
    int fp[maxn];         //与p[]数组相反
    int son[maxn];        //重儿子
    int pos;
    int w[maxn];
    int ad[maxn << 2];  //树状数组
    int n;              //节点数目
    void init() {
        memset(head, -1, sizeof(head));
        memset(son, -1, sizeof(son));
        tot = 0;
        pos = 1;  //因为使用树状数组,所以我们pos初始值从1开始
    }
    void addedge(int u, int v) {
        edge[tot].to = v;
        edge[tot].next = head[u];
        head[u] = tot++;
    }
    //第一遍dfs,求出 fa,deep,num,son (u为当前节点,pre为其父节点,d为深度)
    void dfs1(int u, int pre, int d) {
        deep[u] = d;
        fa[u] = pre;
        num[u] = 1;
        //遍历u的邻接点
        for (int i = head[u]; i != -1; i = edge[i].next) {
            int v = edge[i].to;
            if (v != pre) {
                dfs1(v, u, d + 1);
                num[u] += num[v];
                if (son[u] == -1 || num[v] > num[son[u]])  //寻找重儿子
                    son[u] = v;
            }
        }
    }
    //第二遍dfs,求出 top,p
    void dfs2(int u, int sp) {
        top[u] = sp;
        p[u] = pos++;
        fp[p[u]] = u;
        if (son[u] != -1)  //如果当前点存在重儿子,继续延伸形成重链
            dfs2(son[u], sp);
        else
            return;
        for (int i = head[u]; i != -1; i = edge[i].next) {
            int v = edge[i].to;
            if (v != son[u] && v != fa[u])  //遍历所有轻儿子新建重链
                dfs2(v, v);
        }
    }
    int lowbit(int x) {
        return x & -x;
    }
    //查询
    int query(int i) {
        int s = 0;
        while (i > 0) {
            s += ad[i];
            i -= lowbit(i);
        }
        return s;
    }
    //增加
    void add(int i, int val) {
        while (i <= n) {
            ad[i] += val;
            i += lowbit(i);
        }
    }
    void update(int u, int v, int val) {
        int f1 = top[u], f2 = top[v];
        while (f1 != f2) {
            if (deep[f1] < deep[f2]) {
                swap(f1, f2);
                swap(u, v);
            }
            //因为区间减法成立,所以我们把对某个区间[f1,u]
            //的更新拆分为 [0,f1] 和 [0,u] 的操作
            add(p[f1], val);
            add(p[u] + 1, -val);
            u = fa[f1];
            f1 = top[u];
        }
        if (deep[u] > deep[v])
            swap(u, v);
        add(p[u], val);
        add(p[v] + 1, -val);
    }
    int main() {
        ios::sync_with_stdio(false);
        int m, ps;
        while (cin >> n >> m >> ps) {
            int a, b, c;
            for (int i = 1; i <= n; i++)
                cin >> w[i];
            init();
            for (int i = 0; i < m; i++) {
                cin >> a >> b;
                addedge(a, b);
                addedge(b, a);
            }
            dfs1(1, 0, 0);
            dfs2(1, 1);
            memset(ad, 0, sizeof(ad));
            for (int i = 1; i <= n; i++) {
                add(p[i], w[i]);
                add(p[i] + 1, -w[i]);
            }
            for (int i = 0; i < ps; i++) {
                char op;
                cin >> op;
                if (op == 'Q') {
                    cin >> a;
                    cout << query(p[a]) << endl;
                } else {
                    cin >> a >> b >> c;
                    if (op == 'D')
                        c = -c;
                    update(a, b, c);
                }
            }
        }
        return 0;
    }
    

    利用树链求LCA

    这个部分参考了Peocco学长,十分感谢

    在这道经典题中,求了LCA,但为什么树剖就可以求LCA呢?

    树剖可以单次 (O(log n))! 地求LCA,且常数较小。假如我们要求两个节点的LCA,如果它们在同一条链上,那直接输出深度较小的那个节点就可以了。

    否则,LCA要么在链头深度较小的那条链上,要么就是两个链头的父节点的LCA,但绝不可能在链头深度较大的那条链上[1]。所以我们可以直接把链头深度较大的节点用其链头的父节点代替,然后继续求它与另一者的LCA。

    由于在链上我们可以 (O(1)) 地跳转,每条链间由轻边连接,而经过轻边的次数又不超过 [公式] ,所以我们实现了 (O(log n)) 的LCA查询。

    int lca(int a, int b) {
        while (top[a] != top[b]) {
            if (dep[top[a]] > dep[top[b]])
                a = fa[top[a]];
            else
                b = fa[top[b]];
        }
        return (dep[a] > dep[b] ? b : a);
    }
    

    结合数据结构

    在进行了树链剖分后,我们便可以配合线段树等数据结构维护树上的信息,这需要我们改一下第二次 DFS 的代码,我们用dfsn数组记录每个点的dfs序,用madfsn数组记录每棵子树的最大dfs序:(这里有点像连通图的知识了)

    // 需要先把根节点的top初始化为自身
    int cnt;
    void dfs2(int p) {
        madfsn[p] = dfsn[p] = ++cnt;
        if (hson[p] != 0) {
            top[hson[p]] = top[p];
            dfs2(hson[p]);
            madfsn[p] = max(madfsn[p], madfsn[hson[p]]);
        }
        for (auto q : edges[p])
            if (!top[q]) {
                top[q] = q;
                dfs2(q);
                madfsn[p] = max(madfsn[p], madfsn[q]);
            }
    }
    

    注意到,每棵子树的dfs序都是连续的,且根节点dfs序最小;而且,如果我们优先遍历重子节点,那么同一条链上的节点的dfs序也是连续的,且链头节点dfs序最小

    连通树(雾)

    所以就可以用线段树等数据结构维护区间信息(以点权的和为例),例如路径修改(类似于求LCA的过程):

    void update_path(int x, int y, int z) {
        while (top[x] != top[y]) {
            if (dep[top[x]] > dep[top[y]]) {
                update(dfsn[top[x]], dfsn[x], z);
                x = fa[top[x]];
            } else {
                update(dfsn[top[y]], dfsn[y], z);
                y = fa[top[y]];
            }
        }
        if (dep[x] > dep[y])
            update(dfsn[y], dfsn[x], z);
        else
            update(dfsn[x], dfsn[y], z);
    }
    

    路径查询:

    int query_path(int x, int y) {
        int ans = 0;
        while (top[x] != top[y]) {
            if (dep[top[x]] > dep[top[y]]) {
                ans += query(dfsn[top[x]], dfsn[x]);
                x = fa[top[x]];
            } else {
                ans += query(dfsn[top[y]], dfsn[y]);
                y = fa[top[y]];
            }
        }
        if (dep[x] > dep[y])
            ans += query(dfsn[y], dfsn[x]);
        else
            ans += query(dfsn[x], dfsn[y]);
        return ans;
    }
    

    子树修改(更新):

    void update_subtree(int x, int z){
        update(dfsn[x], madfsn[x], z);
    }
    

    子树查询:

    int query_subtree(int x){
        return query(dfsn[x], madfsn[x]);
    }
    

    需要注意,建线段树的时候不是按节点编号建,而是按dfs序建,类似这样:

    for (int i = 1; i <= n; ++i)
        B[i] = read();
    // ...
    for (int i = 1; i <= n; ++i)
        A[dfsn[i]] = B[i];
    build();
    

    当然,不仅可以用线段树维护,有些题也可以使用珂朵莉树等数据结构(要求数据不卡珂朵莉树,如这道)。此外,如果需要维护的是边权而不是点权,把每条边的边权下放到深度较深的那个节点处即可,但是查询、修改的时候要注意略过最后一个点。

    写在最后:

    OI wiki上有一些推荐做的列题,但每个都需要比较多的时间+耐心去完成,所以这里推荐几个必做的题:

    SPOJ QTREE – Query on a tree (树链剖分):千千dalao的题解报告

    HDU 3966 Aragorn’s Story (树链剖分):建议先看一遍我的解法再独立完成。

    参考

    洛谷日报:https://zhuanlan.zhihu.com/p/41082337

    OI wiki:https://oi-wiki.org/graph/hld/

    Pecco学长:https://www.zhihu.com/people/one-seventh

    千千:https://www.dreamwings.cn/hdu3966/4798.html


    1. 设top[a]的深度≤top[b]的深度,且c=lca(a,b)在b所在的链上;那么c是a和b的祖先且c的深度≥top[b]的深度,那么c的深度≥top[a]的深度。c是a的祖先,top[a]也是a的祖先,c的深度大于等于top[a],那c必然在连接top[a]和a的这条链上,与前提矛盾 ↩︎

  • 相关阅读:
    超有爱的并查集
    写给想当程序员的朋友
    POJ 1961 字符串 KMP (i-next[i])
    POJ 2406 KMP算法next数组理解
    POJ 2387 Bellman双重边
    POJ 1917 字符串替换
    POJ 1062 坑爹的聘礼(枚举等级差选择性找边)
    Linux下libxml2的使用
    浙大pat 1003
    判定一棵二叉树是否是二叉搜索树
  • 原文地址:https://www.cnblogs.com/RioTian/p/14063710.html
Copyright © 2011-2022 走看看