zoukankan      html  css  js  c++  java
  • 树链剖分详解

    树链剖分模板题

    1,将树从x到y结点最短路径上所有节点的值都加上z

    2,求树从x到y结点最短路径上所有节点的值之和

    3,将x为根节点的子树内所有节点的值加上z

    4,求x为根节点的子树内所有节点值之和

    (以下都基于这个题目展开讲解)

    如果没有操作3和4,这题可以用树上差分和lca解决,也是模板题

    树上差分指路:[https://www.cnblogs.com/gzh-red/p/11185914.html]

    求lca指路:[https://www.cnblogs.com/lsdsjy/p/4071041.html]


    好,进入正题

    树链剖分,顾名思义,就是通过轻重边的划分将树分割成多条链,保证每个点属于且只属于一条链,然后再通过数据结构(树状数组、BST、SPLAY、线段树等)来维护每一条链。(摘自百度百科)

    几个基本概念:
    • 重儿子:当前节点所有子结点中子树节点数(size)最多的结点,为当前节点的重儿子
    • 轻儿子:除了重节点的儿子
    • 重边:连接父亲节点和重儿子的边
    • 轻边:连接父亲节点和轻儿子的边
    • 重链:多条重边连接成的路径(也就是说有多条重边在一条连续的路径上)
    • 轻链:多条轻边组成的路径

    比如这样一棵树

    这样一棵树

    ①的子节点中,④的size最大,所以④是①的重儿子,连接①④的边为重边。

    最终dfs一次这张图应该有这些信息

    ![ZqqA1S](C:UserswawaPicturesCamera RollqqA1S.jpg)

    红圈表示重儿子,加粗边为重边

    这个过程的代码实现

    void dfs1(int u, int fa)
    {
        deg[u] = deg[fa] + 1;//deg为节点深度
        f[u] = fa;           //f[u]为u的父亲节点
        sz[u] = 1;           //sz[u]为u所在子树大小(包括自己)
        for (int i = head[u]; i; i = e[i].nex)
        {
            int ne = e[i].to;
            if (ne == fa) continue;
            dfs1(ne, u);
            sz[u] += sz[ne];
            if (sz[ne] > sz[son[u]]) 
                son[u] = ne;
        }
    }
    

    注意的是:如果u为叶节点,它是没有重儿子的。如果u有多个子结点子树大小相等,随便谁当重儿子都行。


    第二遍dfs要将重边连成重链,保证一条重链上的节点dfs序连续,以便用用数据结构维护(比如线段树肯定是对连续的区间进行维护)同时要处理出重链的链头,也就是假如一条重链深度最小的点为①,然后有②③④,则top[1],top[2],top[3]和top[4]均为1。

    void dfs2(int u, int top_fa)
    {
        xu[u] = ++inde; 
        v[inde] = w[u]; 
        top[u] = top_fa;
        if (!son[u]) return ;//如果为叶节点,返回
        dfs2(son[u], top_fa);//优先走重边
        for (int i = head[u]; i; i = e[i].nex)
        {
            int ne = e[i].to;
            if (ne == f[u] || ne == son[u]) continue;
            dfs2(ne, ne);
        }
    }
    

    剖分的工作就做完了,接下来就是用数据结构维护

    因为一条重链上的节点dfs序连续,那么路径和以及路径修改就可以转化成区间求和以及区间修改,那么可以用线段树维护。

    以查询为例 ,还是用求lca的方式,用top可以直接跳到一条重链的起始位置,让top[x]和top[y]中比较深的节点来跳,直接跳到对应top[]的父节点,可以保证两个节点一定不会擦肩而过,也能保证最好两个点一定能跳到同一条重链上。

    ll query(int rt, int l, int r, int x, int y)
    {
        if (x <= l && r <= y) return sum[rt];
        pushdown(rt, l, r);
        int mid = (l + r) >> 1;
        ll res = 0;
        if (x <= mid) 
            res = (res + query(lson, l, mid, x, y)) % mod;
        if (mid < y) 
            res = (res + query(rson, mid+1, r, x, y)) % mod;
        return res;
    }
    ll qRange(int x, int y)
    {
        ll ans=0;
        while (top[x] != top[y])
        {
            if (deg[top[x]] < deg[top[y]])
                swap(x,y);
            ans = (ans + query(1, 1, n, xu[top[x]], xu[x])) % mod;
            x = f[top[x]];
        }
        if (deg[x] > deg[y])
            swap(x, y);
        ans = (ans + query(1, 1, n, xu[x], xu[y])) % mod;
        return ans;
    }
    

    可以用上面的图跳几对点手动模拟一下,加深理解。

    完整代码

    #include<cstdio>
    #include<cstring>
    #include<iostream>
    #include<sstream>
    #include<algorithm>
    #include<vector>
    #include<queue>
    #include<map>
    #include<cstdlib>
    #include<cmath>
    using namespace std;
    #define re register int
    #define ull unsigned long long
    #define ll long long
    #define inf 0x3f3f3f3f
    #define N 1000010
    #define lson rt<<1
    #define rson rt<<1|1
    #define lowbit(x) (x)&(-(x))
    void FRE(){freopen("subsets.in","r",stdin);freopen("subsets.out","w",stdout);}
    inline ll read()
    {
        ll x=0,f=1;char ch=getchar();
        while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
        while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ll)(ch-'0');ch=getchar();}
        return x*f;
    }
    
    int cnt, head[N], n, m, root, xu[N], deg[N], top[N];
    int f[N],sz[N],son[N],inde;
    ll mod,sum[N],v[N],w[N],tag[N];
    struct node
    {
        int to,nex;
    }e[N*2];
    
    void add(int u, int v)
    {
        e[++cnt].nex = head[u];
        head[u] = cnt;
        e[cnt].to=v;
    }
    
    void dfs1(int u, int fa)
    {
        deg[u] = deg[fa] + 1;//deg为节点深度
        f[u] = fa;           //f[u]为u的父亲节点
        sz[u] = 1;           //sz[u]为u所在子树大小(包括自己)
        for (int i = head[u]; i; i = e[i].nex)
        {
            int ne = e[i].to;
            if (ne == fa) continue;
            dfs1(ne, u);
            sz[u] += sz[ne];
            if (sz[ne] > sz[son[u]]) 
                son[u] = ne;
        }
    }
    
    void dfs2(int u, int top_fa)
    {
        xu[u] = ++inde; 
        v[inde] = w[u]; 
        top[u] = top_fa;
        if (!son[u]) return ;//如果为叶节点,返回
        dfs2(son[u], top_fa);//优先走重边
        for (int i = head[u]; i; i = e[i].nex)
        {
            int ne = e[i].to;
            if (ne == f[u] || ne == son[u]) continue;
            dfs2(ne, ne);
        }
    }
    
    void pushup(int rt)
    {
        sum[rt] = (sum[lson] + sum[rson])%mod;
    }
    
    void pushdown(int rt, int l, int r)
    {
        if (!tag[rt]) return;
        int mid = (l + r) >> 1;
        tag[lson] = (tag[lson] + tag[rt]) % mod;
        tag[rson] = (tag[rson] + tag[rt]) %mod;
        sum[lson] = (sum[lson] + (tag[rt] * (ll)(mid - l + 1)) % mod) % mod;
        sum[rson] = (sum[rson] + (tag[rt] * (ll)(r-mid))  % mod) % mod;
        tag[rt] = 0;
    }
    void build(int rt, int l, int r)
    {
        if (l == r)
        {
            sum[rt] = v[l];
            return ;
        }
        int mid = (l + r)>>1;
        build(lson, l, mid);
        build(rson, mid+1, r);
        pushup(rt);
    }
    
    void update(int rt, int l, int r, int x, int y, int val)
    {
        if (x <= l && r<=y) 
        {
            sum[rt] = (sum[rt] + (r-l+1) * val) % mod;
            tag[rt] = (tag[rt] + val) % mod;
            return ;
        }
        pushdown(rt, l, r);
        int mid = (l + r) >> 1;
        if (x <= mid)
            update(lson, l, mid, x, y, val);
        if (mid < y) 
            update(rson, mid+1, r, x, y, val);
        pushup(rt);
    }
    
    void upRange(int x, int y, int val)
    {
        while (top[x] != top[y])
        {
            if (deg[top[x]] < deg[top[y]])
                swap(x,y);
            update(1, 1, n, xu[top[x]], xu[x], val);
            x = f[top[x]];
        }
        if (deg[x] >deg[y])
            swap(x,y);
        update(1, 1, n, xu[x], xu[y], val);
    }
    
    ll query(int rt, int l, int r, int x, int y)
    {
        if (x <= l && r <= y) return sum[rt];
        pushdown(rt, l, r);
        int mid = (l + r) >> 1;
        ll res = 0;
        if (x <= mid) 
            res = (res + query(lson, l, mid, x, y)) % mod;
        if (mid < y) 
            res = (res + query(rson, mid+1, r, x, y)) % mod;
        return res;
    }
    ll qRange(int x, int y)
    {
        ll ans=0;
        while (top[x] != top[y])
        {
            if (deg[top[x]] < deg[top[y]])
                swap(x,y);
            ans = (ans + query(1, 1, n, xu[top[x]], xu[x])) % mod;
            x = f[top[x]];
        }
        if (deg[x] > deg[y])
            swap(x, y);
        ans = (ans + query(1, 1, n, xu[x], xu[y])) % mod;
        return ans;
    }
    
    int main()
    {
    	n = read(); m = read();
        root = read(); mod = read();
        for (int i = 1; i <= n; i++)
            w[i] = read();
        for (int i = 1; i < n; i++)
        {
            int x = read(), y = read();
            add(x, y), add(y, x);
        }
        dfs1(root, 0); dfs2(root, root);
        build(1, 1, n);
        while (m--)
        {
            int ty = read();
            if (ty == 1)
            {
                int x = read(), y = read();
                ll z = read() % mod;
                upRange(x, y, z);
            }
            if (ty == 2)
            {
                int x = read(), y = read();
                printf("%lld
    ", qRange(x, y));
            }
            if (ty == 3)
            {
                int x = read();
                ll z = read();
                update(1, 1, n, xu[x], xu[x] + sz[x] - 1, z);
            }
            if (ty == 4)
            {
                int x = read();
                printf("%lld
    ", query(1, 1, n, xu[x], xu[x] + sz[x] - 1));
            }
        }
    	return 0;
    }
    

    练习:[NOI2015]软件包管理器

    这个题目就可以理解成区间修改,每次修改的是根节点到x路径

    完整代码

    #include <bits/stdc++.h>
    #define inf 1e18
    #define ll long long
    #define N 1000010
    #define lson rt << 1
    #define rson rt << 1 | 1
    #define mo 998244353
    using namespace std;
    typedef pair<int, int> P;
    inline ll read()
    {
        ll x = 0, f = 1;
        char ch = getchar();
        while (ch < '0' || ch > '9')
        {
            if (ch == '-')
                f = -1;
            ch = getchar();
        }
        while (ch >= '0' && ch <= '9')
            x = x * 10 + ch - '0', ch = getchar();
        return x * f;
    }
    int f[N], sz[N], son[N], head[N], cnt, deg[N], xu[N], len, tag[N], sum[N], top[N], n;
    char ty[20];
    struct node
    {
        int nex, to;
    } e[N * 2];
    void add(int u, int v)
    {
        e[++cnt].nex = head[u];
        head[u] = cnt;
        e[cnt].to = v;
    }
    void dfs1(int u, int fa)
    {
        f[u] = fa;
        sz[u] = 1;
        deg[u] = deg[fa] + 1;
        for (int i = head[u]; i; i = e[i].nex)
        {
            int ne = e[i].to;
            if (ne == fa)
                continue;
            dfs1(ne, u);
            sz[u] += sz[ne];
            if (sz[ne] > sz[son[u]])
                son[u] = ne;
        }
    }
    void dfs2(int u, int fa)
    {
        top[u] = fa;
        xu[u] = ++len;
        if (!son[u])
            return;
        dfs2(son[u], fa);
        for (int i = head[u]; i; i = e[i].nex)
        {
            int ne = e[i].to;
            if (ne == f[u] || ne == son[u])
                continue;
            dfs2(ne, ne);
        }
    }
    void pushdown(int rt, int l, int r)
    {
        if (tag[rt] == -1)
            return;
        int mid = (l + r) >> 1;
        tag[lson] = tag[rt], tag[rson] = tag[rt];
        sum[lson] = (mid - l + 1) * tag[rt], sum[rson] = (r - mid) * tag[rt];
        tag[rt] = -1;
    }
    void pushup(int rt) { sum[rt] = sum[lson] + sum[rson]; }
    void modify(int rt, int l, int r, int x, int y, int val)
    {
        if (x <= l && r <= y)
        {
            sum[rt] = (r - l + 1) * val;
            tag[rt] = val;
            return;
        }
        pushdown(rt, l, r);
        int mid = (l + r) >> 1;
        if (x <= mid)
            modify(lson, l, mid, x, y, val);
        if (y > mid)
            modify(rson, mid + 1, r, x, y, val);
        pushup(rt);
    }
    void upRange(int x, int y)
    {
        while (top[x] != top[y])
        {
            if (deg[top[x]] < deg[top[y]])
                swap(x, y);
            modify(1, 1, n, xu[top[x]], xu[x], 1);
            x = f[top[x]];
        }
        if (deg[x] > deg[y])
            swap(x, y);
        modify(1, 1, n, xu[x], xu[y], 1);
    }
    int query(int rt, int l, int r, int x, int y)
    {
        if (x <= l && r <= y)
            return sum[rt];
        pushdown(rt, l, r);
        int mid = (l + r) >> 1, res = 0;
        if (x <= mid)
            res += query(lson, l, mid, x, y);
        if (y > mid)
            res += query(rson, mid + 1, r, x, y);
        return res;
    }
    int main()
    {
        n = read();
        for (int i = 2; i <= n; i++)
        {
            int x = read();
            add(x + 1, i), add(i, x + 1);
        }
        dfs1(1, 0);
        dfs2(1, 1);
        int Q = read();
        while (Q--)
        {
            scanf("%s", ty);
            int x = read() + 1;
            int tmp = sum[1];
            if (ty[0] == 'i')
            {
                upRange(1, x);
                printf("%d
    ", sum[1] - tmp);
            }
            else
            {
                modify(1, 1, n, xu[x], xu[x] + sz[x] - 1, 0);
                printf("%d
    ", tmp - sum[1]);
            }
        }
        return 0;
    }
    
  • 相关阅读:
    制作 MarkText 的导航栏和动画背景
    某雅互动静态页面
    html5 拖拽及用 js 实现拖拽
    九宫格
    phaser3 入门实例——收集星星游戏
    Flexbox Froggy:练习 Flex 布局的小游戏
    JS30
    ElasticSearch
    JVM
    jstack命令的使用
  • 原文地址:https://www.cnblogs.com/71-111/p/13943126.html
Copyright © 2011-2022 走看看