zoukankan      html  css  js  c++  java
  • 【数据结构】树链剖分详细讲解

     “在一棵树上进行路径的修改、求极值、求和”乍一看只要线段树就能轻松解决,实际上,仅凭线段树是不能搞定它的。我们需要用到一种貌似高级的复杂算法——树链剖分。

    树链剖分是把一棵树分割成若干条链,以便于维护信息的一种方法,其中最常用的是重链剖分(Heavy Path Decomposition,重路径分解),所以一般提到树链剖分或树剖都是指重链剖分。除此之外还有长链剖分和实链剖分等,本文暂不介绍。

    首先我们需要明确概念:

    • 重儿子:父亲节点的所有儿子中子树结点数目最多(size最大)的结点;
    • 轻儿子:父亲节点中除了重儿子以外的儿子;
    • 重边:父亲结点和重儿子连成的边;
    • 轻边:父亲节点和轻儿子连成的边;
    • 重链:由多条重边连接而成的路径;
    • 轻链:由多条轻边连接而成的路径;

    我们定义树上一个节点的子节点中子树最大的一个为它的重子节点,其余的为轻子节点。一个节点连向其重子节点的边称为重边,连向轻子节点的边则为轻边。如果把根节点看作轻的,那么从每个轻节点出发,不断向下走重边,都对应了一条链,于是我们把树剖分成了 (l) 条链,其中 (l) 是轻节点的数量。

    最近因为画图工具出了点问题,所以转载了Pecco学长的示意图(下面求LCA的方法的部分内容也来自Pecco学长)

    剖分后的树(重链)有如下性质:

    1. 对于节点数为 (n) 的树,从任意节点向上走到根节点,经过的轻边数量不会超过 (log n)

      这是因为当我们向下经过一条 轻边 时,所在子树的大小至少会除以二。所以说,对于树上的任意一条路径,把它拆分成从 (lca) 分别向两边往下走,分别最多走 (O(log n)) 次,树上的每条路径都可以被拆分成不超过 (O(log n)) 条重链。

    2. 树上每个节点都属于且仅属于一条重链

    重链开头的结点不一定是重子节点(因为重边是对于每一个结点都有定义的)。所有的重链将整棵树 完全剖分

    尽管树链部分看起来很难实现(的确有点繁琐),但我们可以用两个 DFS 来实现树链(树剖)。

    相关伪码(来自 OI wiki)

    第一个 DFS 记录每个结点的父节点(father)、深度(deep)、子树大小(size)、重子节点(hson)。

    [egin{array}{l} ext{TREE-BUILD }(u,dep) \ egin{array}{ll} 1 & u.hsongets 0 \ 2 & u.hson.sizegets 0 \ 3 & u.deepgets dep \ 4 & u.sizegets 1 \ 5 & extbf{for } ext{each }u ext{'s son }v \ 6 & qquad u.sizegets u.size + ext{TREE-BUILD }(v,dep+1) \ 7 & qquad v.fathergets u \ 8 & qquad extbf{if }v.size> u.hson.size \ 9 & qquad qquad u.hsongets v \ 10 & extbf{return } u.size end{array} end{array} ]

    第二个 DFS 记录所在链的链顶(top,应初始化为结点本身)、重边优先遍历时的 DFS 序(dfn)、DFS 序对应的节点编号(rank)。

    [egin{array}{l} ext{TREE-DECOMPOSITION }(u,top) \ egin{array}{ll} 1 & u.topgets top \ 2 & totgets tot+1\ 3 & u.dfngets tot \ 4 & rank(tot)gets u \ 5 & extbf{if }u.hson ext{ is not }0 \ 6 & qquad ext{TREE-DECOMPOSITION }(u.hson,top) \ 7 & qquad extbf{for } ext{each }u ext{'s son }v \ 8 & qquad qquad extbf{if }v ext{ is not }u.hson \ 9 & qquad qquad qquad ext{TREE-DECOMPOSITION }(v,v) end{array} end{array} ]

    以下为代码实现。

    我们先给出一些定义:

    • (fa(x)) 表示节点 (x) 在树上的父亲(也就是父节点)。
    • (dep(x)) 表示节点 (x) 在树上的深度。
    • (siz(x)) 表示节点 (x) 的子树的节点个数。
    • (son(x)) 表示节点 (x)重儿子
    • (top(x)) 表示节点 (x) 所在 重链 的顶部节点(深度最小)。
    • (dfn(x)) 表示节点 (x)DFS 序 ,也是其在线段树中的编号。
    • (rnk(x)) 表示 DFS 序所对应的节点编号,有 (rnk(dfn(x))=x)

    我们进行两遍 DFS 预处理出这些值,其中第一次 DFS 求出 (fa(x)) , (dep(x)) , (siz(x)) , (son(x)) ,第二次 DFS 求出 (top(x)) , (dfn(x)) , (rnk(x))

    // 当然树链写法不止一种,这个是我学习Oi wiki上知识点记录的模板代码
    void dfs1(int o) {
        son[o] = -1, siz[o] = 1;
        for (int j = h[o]; j; j = nxt[j])
            if (!dep[p[j]]) {
                dep[p[j]] = dep[o] + 1;
                fa[p[j]] = o;
                dfs1(p[j]);
                siz[o] += siz[p[j]];
                if (son[o] == -1 || siz[p[j]] > siz[son[o]])
                    son[o] = p[j];
            }
    }
    void dfs2(int o, int t) {
        top[o] = t;
        dfn[o] = ++cnt;
        rnk[cnt] = o;
        if (son[o] == -1)
            return;
        dfs2(son[o], t);  // 优先对重儿子进行 DFS,可以保证同一条重链上的点 DFS 序连续
        for (int j = h[o]; j; j = nxt[j])
            if (p[j] != son[o] && p[j] != fa[o])
                dfs2(p[j], p[j]);
    }
    
    // 写法2:来自Peocco学长,代码仅作学习使用
    void dfs1(int p, int d = 1){
        int Siz = 1,ma = 0;
        dep[p] = d;
        for(auto q : edges[p]){ // for循环写法和auto是C++11标准,竞赛可用
            dfs1(q,d + 1);
            fa[q] = p;
            Siz += sz[q];
            if(sz[q] > ma)
                hson[p] = q, ma = sz[q];// hson = 重儿子
        }
        sz[p] = Siz; 
    }
    // 需要先把根节点的top初始化为自身
    void dfs2(int p){
        for(auto q : edges[p]){
            if(!top[q]){
                if(q == hson[p])
                    top[q] = top[p];
               	else 
                    top[q] = q;
                dfs2(q);
            }
        }
    }
    

    以上这样便完成了剖分。

    学习到这里想想开头的那句话:

     “在一棵树上进行路径的修改、求极值、求和”乍一看只要线段树就能轻松解决,实际上,仅凭线段树是不能搞定它的。我们需要用到一种貌似高级的复杂算法——树链剖分。

    如果不能一下想不到线段树解决不了的问题的话不如看看这道题 ↓

    Hdu 3966 Aragorn's Story

    题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=3966

    题意:给一棵树,并给定各个点权的值,然后有3种操作:
      I C1 C2 K: 把C1与C2的路径上的所有点权值加上K
      D C1 C2 K:把C1与C2的路径上的所有点权值减去K
      Q C:查询节点编号为C的权值

      分析:典型的树链剖分题目,先进行剖分,然后用线段树去维护即可

    // Author : RioTian
    // Time : 20/11/30
    #include <bits/stdc++.h>
    using namespace std;
    
    #define lson l, m, rt << 1
    #define rson m + 1, r, rt << 1 | 1
    
    typedef long long ll;
    typedef int lld;
    
    stack<int> ss;
    const int maxn = 2e5 + 10;
    const int inf = ~0u >> 2;  // 1073741823
    int M[maxn << 2];
    int add[maxn << 2];
    
    struct node {
        int s, t, w, next;
    } edges[maxn << 1];
    
    int E, n;
    int Size[maxn], fa[maxn], heavy[maxn], head[maxn], vis[maxn];
    int dep[maxn], rev[maxn], num[maxn], cost[maxn], w[maxn];
    int Seg_size;
    
    int find(int x) {
        return fa[x] == x ? x : fa[x] = find(fa[x]);
    }
    
    void add_edge(int s, int t, int w) {
        edges[E].w = w;
        edges[E].s = s;
        edges[E].t = t;
        edges[E].next = head[s];
        head[s] = E++;
    }
    
    void dfs(int u, int f) {  //起点,父节点
        int mx = -1, e = -1;
        Size[u] = 1;
        for (int i = head[u]; i != -1; i = edges[i].next) {
            int v = edges[i].t;
            if (v == f)
                continue;
            edges[i].w = edges[i ^ 1].w = w[v];
            dep[v] = dep[u] + 1;
            rev[v] = i ^ 1;
            dfs(v, u);
            Size[u] += Size[v];
            if (Size[v] > mx)
                mx = Size[v], e = i;
        }
        heavy[u] = e;
        if (e != -1)
            fa[edges[e].t] = u;
    }
    
    inline void pushup(int rt) {
        M[rt] = M[rt << 1] + M[rt << 1 | 1];
    }
    
    void pushdown(int rt, int m) {
        if (add[rt]) {
            add[rt << 1] += add[rt];
            add[rt << 1 | 1] += add[rt];
            M[rt << 1] += add[rt] * (m - (m >> 1));
            M[rt << 1 | 1] += add[rt] * (m >> 1);
            add[rt] = 0;
        }
    }
    
    void built(int l, int r, int rt) {
        M[rt] = add[rt] = 0;
        if (l == r)
            return;
        int m = (r + l) >> 1;
        built(lson), built(rson);
    }
    
    void update(int L, int R, int val, int l, int r, int rt) {
        if (L <= l && r <= R) {
            M[rt] += val;
            add[rt] += val;
            return;
        }
        pushdown(rt, r - l + 1);
        int m = (l + r) >> 1;
        if (L <= m)
            update(L, R, val, lson);
        if (R > m)
            update(L, R, val, rson);
        pushup(rt);
    }
    
    lld query(int L, int R, int l, int r, int rt) {
        if (L <= l && r <= R)
            return M[rt];
        pushdown(rt, r - l + 1);
        int m = (l + r) >> 1;
        lld ret = 0;
        if (L <= m)
            ret += query(L, R, lson);
        if (R > m)
            ret += query(L, R, rson);
        return ret;
    }
    
    void prepare() {
        int i;
        built(1, n, 1);
        memset(num, -1, sizeof(num));
        dep[0] = 0;
        Seg_size = 0;
        for (i = 0; i < n; i++)
            fa[i] = i;
        dfs(0, 0);
        for (i = 0; i < n; i++) {
            if (heavy[i] == -1) {
                int pos = i;
                while (pos && edges[heavy[edges[rev[pos]].t]].t == pos) {
                    int t = rev[pos];
                    num[t] = num[t ^ 1] = ++Seg_size;
                    // printf("pos=%d  val=%d t=%d
    ", Seg_size, edge[t].w, t);
                    update(Seg_size, Seg_size, edges[t].w, 1, n, 1);
                    pos = edges[t].t;
                }
            }
        }
    }
    
    int lca(int u, int v) {
        while (1) {
            int a = find(u), b = find(v);
            if (a == b)
                return dep[u] < dep[v] ? u : v;  // a,b在同一条重链上
            else if (dep[a] >= dep[b])
                u = edges[rev[a]].t;
            else
                v = edges[rev[b]].t;
        }
    }
    
    void CH(int u, int lca, int val) {
        while (u != lca) {
            int r = rev[u];  // printf("r=%d
    ",r);
            if (num[r] == -1)
                edges[r].w += val, u = edges[r].t;
            else {
                int p = fa[u];
                if (dep[p] < dep[lca])
                    p = lca;
                int l = num[r];
                r = num[heavy[p]];
                update(l, r, val, 1, n, 1);
                u = p;
            }
        }
    }
    
    void change(int u, int v, int val) {
        int p = lca(u, v);
        // printf("p=%d
    ",p);
        CH(u, p, val);
        CH(v, p, val);
        if (p) {
            int r = rev[p];
            if (num[r] == -1) {
                edges[r ^ 1].w += val;  //在此处发现了我代码的重大bug
                edges[r].w += val;
            } else
                update(num[r], num[r], val, 1, n, 1);
        }  //根节点,特判
        else
            w[p] += val;
    }
    
    lld solve(int u) {
        if (!u)
            return w[u];  //根节点,特判
        else {
            int r = rev[u];
            if (num[r] == -1)
                return edges[r].w;
            else
                return query(num[r], num[r], 1, n, 1);
        }
    }
    
    int main() {
        // freopen("in.txt", "r", stdin);
        ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
        int t, i, a, b, c, m, ca = 1, p;
        while (cin >> n >> m >> p) {
            memset(head, -1, sizeof(head));
            E = 0;
            for (int i = 0; i < n; ++i)
                cin >> w[i];
            for (int i = 0; i < m; ++i) {
                cin >> a >> b;
                a--, b--;
                add_edge(a, b, 0), add_edge(b, a, 0);
            }
            prepare();  // 预处理
            string op;
            while (p--) {
                cin >> op;
                if (op[0] == 'I') {  //区间添加
                    cin >> a >> b >> c;
                    a--, b--;
                    change(a, b, c);
                } else if (op[0] == 'D') {  //区间减少
                    cin >> a >> b >> c;
                    a--, b--;
                    change(a, b, -c);
                } else {  //查询
                    cin >> a;
                    a--;
                    cout << solve(a) << endl;
                }
            }
        }
        return 0;
    }
    

    由于数据很大,建议使用快读,而不是像我一样用 cin(差了近500ms了)

    折叠代码是千千dalao的解法:

    Code
    //千千dalao解法
    #include<bits/stdc++.h>
    using namespace std;
    typedef long long LL;
    const int maxn = 50010;
    struct Edge {
        int to;
        int next;
    } edge[maxn << 1];
    int head[maxn], tot;  //链式前向星存储
    int top[maxn];        // v所在重链的顶端节点
    int fa[maxn];         //父亲节点
    int deep[maxn];       //节点深度
    int num[maxn];        //以v为根的子树节点数
    int p[maxn];          // v与其父亲节点的连边在线段树中的位置
    int fp[maxn];         //与p[]数组相反
    int son[maxn];        //重儿子
    int pos;
    int w[maxn];
    int ad[maxn << 2];  //树状数组
    int n;              //节点数目
    void init() {
        memset(head, -1, sizeof(head));
        memset(son, -1, sizeof(son));
        tot = 0;
        pos = 1;  //因为使用树状数组,所以我们pos初始值从1开始
    }
    void addedge(int u, int v) {
        edge[tot].to = v;
        edge[tot].next = head[u];
        head[u] = tot++;
    }
    //第一遍dfs,求出 fa,deep,num,son (u为当前节点,pre为其父节点,d为深度)
    void dfs1(int u, int pre, int d) {
        deep[u] = d;
        fa[u] = pre;
        num[u] = 1;
        //遍历u的邻接点
        for (int i = head[u]; i != -1; i = edge[i].next) {
            int v = edge[i].to;
            if (v != pre) {
                dfs1(v, u, d + 1);
                num[u] += num[v];
                if (son[u] == -1 || num[v] > num[son[u]])  //寻找重儿子
                    son[u] = v;
            }
        }
    }
    //第二遍dfs,求出 top,p
    void dfs2(int u, int sp) {
        top[u] = sp;
        p[u] = pos++;
        fp[p[u]] = u;
        if (son[u] != -1)  //如果当前点存在重儿子,继续延伸形成重链
            dfs2(son[u], sp);
        else
            return;
        for (int i = head[u]; i != -1; i = edge[i].next) {
            int v = edge[i].to;
            if (v != son[u] && v != fa[u])  //遍历所有轻儿子新建重链
                dfs2(v, v);
        }
    }
    int lowbit(int x) {
        return x & -x;
    }
    //查询
    int query(int i) {
        int s = 0;
        while (i > 0) {
            s += ad[i];
            i -= lowbit(i);
        }
        return s;
    }
    //增加
    void add(int i, int val) {
        while (i <= n) {
            ad[i] += val;
            i += lowbit(i);
        }
    }
    void update(int u, int v, int val) {
        int f1 = top[u], f2 = top[v];
        while (f1 != f2) {
            if (deep[f1] < deep[f2]) {
                swap(f1, f2);
                swap(u, v);
            }
            //因为区间减法成立,所以我们把对某个区间[f1,u]
            //的更新拆分为 [0,f1] 和 [0,u] 的操作
            add(p[f1], val);
            add(p[u] + 1, -val);
            u = fa[f1];
            f1 = top[u];
        }
        if (deep[u] > deep[v])
            swap(u, v);
        add(p[u], val);
        add(p[v] + 1, -val);
    }
    int main() {
        ios::sync_with_stdio(false);
        int m, ps;
        while (cin >> n >> m >> ps) {
            int a, b, c;
            for (int i = 1; i <= n; i++)
                cin >> w[i];
            init();
            for (int i = 0; i < m; i++) {
                cin >> a >> b;
                addedge(a, b);
                addedge(b, a);
            }
            dfs1(1, 0, 0);
            dfs2(1, 1);
            memset(ad, 0, sizeof(ad));
            for (int i = 1; i <= n; i++) {
                add(p[i], w[i]);
                add(p[i] + 1, -w[i]);
            }
            for (int i = 0; i < ps; i++) {
                char op;
                cin >> op;
                if (op == 'Q') {
                    cin >> a;
                    cout << query(p[a]) << endl;
                } else {
                    cin >> a >> b >> c;
                    if (op == 'D')
                        c = -c;
                    update(a, b, c);
                }
            }
        }
        return 0;
    }
    

    利用树链求LCA

    这个部分参考了Peocco学长,十分感谢

    在这道经典题中,求了LCA,但为什么树剖就可以求LCA呢?

    树剖可以单次 (O(log n))! 地求LCA,且常数较小。假如我们要求两个节点的LCA,如果它们在同一条链上,那直接输出深度较小的那个节点就可以了。

    否则,LCA要么在链头深度较小的那条链上,要么就是两个链头的父节点的LCA,但绝不可能在链头深度较大的那条链上[1]。所以我们可以直接把链头深度较大的节点用其链头的父节点代替,然后继续求它与另一者的LCA。

    由于在链上我们可以 (O(1)) 地跳转,每条链间由轻边连接,而经过轻边的次数又不超过 [公式] ,所以我们实现了 (O(log n)) 的LCA查询。

    int lca(int a, int b) {
        while (top[a] != top[b]) {
            if (dep[top[a]] > dep[top[b]])
                a = fa[top[a]];
            else
                b = fa[top[b]];
        }
        return (dep[a] > dep[b] ? b : a);
    }
    

    结合数据结构

    在进行了树链剖分后,我们便可以配合线段树等数据结构维护树上的信息,这需要我们改一下第二次 DFS 的代码,我们用dfsn数组记录每个点的dfs序,用madfsn数组记录每棵子树的最大dfs序:(这里有点像连通图的知识了)

    // 需要先把根节点的top初始化为自身
    int cnt;
    void dfs2(int p) {
        madfsn[p] = dfsn[p] = ++cnt;
        if (hson[p] != 0) {
            top[hson[p]] = top[p];
            dfs2(hson[p]);
            madfsn[p] = max(madfsn[p], madfsn[hson[p]]);
        }
        for (auto q : edges[p])
            if (!top[q]) {
                top[q] = q;
                dfs2(q);
                madfsn[p] = max(madfsn[p], madfsn[q]);
            }
    }
    

    注意到,每棵子树的dfs序都是连续的,且根节点dfs序最小;而且,如果我们优先遍历重子节点,那么同一条链上的节点的dfs序也是连续的,且链头节点dfs序最小

    连通树(雾)

    所以就可以用线段树等数据结构维护区间信息(以点权的和为例),例如路径修改(类似于求LCA的过程):

    void update_path(int x, int y, int z) {
        while (top[x] != top[y]) {
            if (dep[top[x]] > dep[top[y]]) {
                update(dfsn[top[x]], dfsn[x], z);
                x = fa[top[x]];
            } else {
                update(dfsn[top[y]], dfsn[y], z);
                y = fa[top[y]];
            }
        }
        if (dep[x] > dep[y])
            update(dfsn[y], dfsn[x], z);
        else
            update(dfsn[x], dfsn[y], z);
    }
    

    路径查询:

    int query_path(int x, int y) {
        int ans = 0;
        while (top[x] != top[y]) {
            if (dep[top[x]] > dep[top[y]]) {
                ans += query(dfsn[top[x]], dfsn[x]);
                x = fa[top[x]];
            } else {
                ans += query(dfsn[top[y]], dfsn[y]);
                y = fa[top[y]];
            }
        }
        if (dep[x] > dep[y])
            ans += query(dfsn[y], dfsn[x]);
        else
            ans += query(dfsn[x], dfsn[y]);
        return ans;
    }
    

    子树修改(更新):

    void update_subtree(int x, int z){
        update(dfsn[x], madfsn[x], z);
    }
    

    子树查询:

    int query_subtree(int x){
        return query(dfsn[x], madfsn[x]);
    }
    

    需要注意,建线段树的时候不是按节点编号建,而是按dfs序建,类似这样:

    for (int i = 1; i <= n; ++i)
        B[i] = read();
    // ...
    for (int i = 1; i <= n; ++i)
        A[dfsn[i]] = B[i];
    build();
    

    当然,不仅可以用线段树维护,有些题也可以使用珂朵莉树等数据结构(要求数据不卡珂朵莉树,如这道)。此外,如果需要维护的是边权而不是点权,把每条边的边权下放到深度较深的那个节点处即可,但是查询、修改的时候要注意略过最后一个点。

    写在最后:

    OI wiki上有一些推荐做的列题,但每个都需要比较多的时间+耐心去完成,所以这里推荐几个必做的题:

    SPOJ QTREE – Query on a tree (树链剖分):千千dalao的题解报告

    HDU 3966 Aragorn’s Story (树链剖分):建议先看一遍我的解法再独立完成。

    参考

    洛谷日报:https://zhuanlan.zhihu.com/p/41082337

    OI wiki:https://oi-wiki.org/graph/hld/

    Pecco学长:https://www.zhihu.com/people/one-seventh

    千千:https://www.dreamwings.cn/hdu3966/4798.html


    1. 设top[a]的深度≤top[b]的深度,且c=lca(a,b)在b所在的链上;那么c是a和b的祖先且c的深度≥top[b]的深度,那么c的深度≥top[a]的深度。c是a的祖先,top[a]也是a的祖先,c的深度大于等于top[a],那c必然在连接top[a]和a的这条链上,与前提矛盾 ↩︎

  • 相关阅读:
    LeetCode153 Find Minimum in Rotated Sorted Array. LeetCode162 Find Peak Element
    LeetCode208 Implement Trie (Prefix Tree). LeetCode211 Add and Search Word
    LeetCode172 Factorial Trailing Zeroes. LeetCode258 Add Digits. LeetCode268 Missing Number
    LeetCode191 Number of 1 Bits. LeetCode231 Power of Two. LeetCode342 Power of Four
    LeetCode225 Implement Stack using Queues
    LeetCode150 Evaluate Reverse Polish Notation
    LeetCode125 Valid Palindrome
    LeetCode128 Longest Consecutive Sequence
    LeetCode124 Binary Tree Maximum Path Sum
    LeetCode123 Best Time to Buy and Sell Stock III
  • 原文地址:https://www.cnblogs.com/RioTian/p/14063710.html
Copyright © 2011-2022 走看看