zoukankan      html  css  js  c++  java
  • 树链剖分

    前置芝士

    dfs序,线段树


    正文

    树链剖分就是通过划分轻重边将树分割成许多链,然后利用数据结构(线段树)来维护这些链

    使得在树上可以用非常优秀的复杂度去遍历一些信息

    (本质上是一种优化暴力(就像LCA)(其实所有数据结构都是优化的暴力))


     

     

    看一个模板题

    首先明确的概念

    重儿子:父亲节点的所有儿子中子树结点数目最多(size最大)的结点;(节点数目包括自身)

    轻儿子:父亲节点中除了重儿子以外的儿子;

    重边:父亲结点和重儿子连成的边;

    轻边:父亲节点和轻儿子连成的边;

    重链:由多条重边连接而成的路径;

    轻链:由多条轻边连接而成的路径;

    比如上面这幅图中,用黑线连接的结点都是重结点,其余均是轻结点,

    2-11就是重链,2-5就是轻链,用红点标记的就是该结点所在重链的起点,也就是下文提到的top结点,

    还有每条边的值其实是进行dfs时的执行序号。

    树链剖分的思路

    将一棵每个节点的儿子按照儿子大小划分成重儿子和轻儿子(其他儿子),将树划分成一条条链(重链和轻链)

    利用dfs序,将同一个链上的点放在一起,在建出线段树

    使得在调用两点的简单路径时,可以一跳跳过多个节点(类比LCA思考)

    从而达到减小复杂度的目的

    如何实现

    有不理解的地方,手模是个不错的选择

    1、先跑第一遍DFS(初始化)

    每遍历到一个点,让siz为1,记录父亲与深度

    然后回溯的时候加上其子树的点的大小

    并顺便在遍历到的子树中挑出重儿子

     1 void dfs(int x, int fa){//
     2         siz[x] = 1, fath[x] = fa, dep[x] = dep[fa] + 1;//确定以x为根的子树的大小,父亲,深度 
     3         //cout<<cnt<<"lzx"<<x<<" "<<fa<<endl;
     4         for(int i = head[x]; i; i = e[i].nxt){//类似于lca初始化的遍历 
     5             int v = e[i].to;
     6             if(v == fa) continue;
     7             dfs(v, x);
     8             siz[x] += siz[v];//回溯的时候更新子树大小 
     9             if(siz[son[x]] < siz[v]) son[x] = v;//挑出重儿子 
    10         } 
    11     }

    2、在跑一遍DFS(分链)

    确定dfs序,并把dfs序所对应的元素用pre数组存起来

    注意遍历顺序,因为开始我们提到划分重链,所以我们要优先遍历重儿子,并把链顶元素也传下去(先遍历重儿子感觉珂以使复杂度最优)

    遍历完重儿子后,再遍历其他儿子,并新开一条链

     1 void dfs2(int x, int tp){//分链,tp表示该链的顶端 
     2         top[x] = tp, dfn[x] = ++cnt, pre[cnt] = x;//确定x节点的链的顶端是tp,x的dfs序及反dfs序 
     3         if(son[x]) dfs2(son[x], tp);//为了使重链的dfn在一起,要先遍历重儿子 
     4         for(int i = head[x]; i; i = e[i].nxt){
     5             int v = e[i].to;
     6             if(v == fath[x] || son[x] == v) continue;//如果一个点的fa等于自己或者下一个点是它的重儿子就跳过
     7             //(如果是重儿子的话应该在以前就已经遍历了,所以还有防止在遍历一遍的作用 
     8             dfs2(v, v);//新开一条链 
     9         }
    10     }

    3、数据维护

    我们不难发现,每个重链的dfs序是连在一起的,那么我们是不是可以考虑用线段树来维护它,因为线段树刚好可以维护一段连续的区间

    线段树板中板

     1 #define lson i << 1
     2     #define rson i << 1 | 1
     3     struct Tree{//和,懒标记,长度 
     4         int sum, lazy, len;
     5     }tree[MAXN << 2];
     6     void push_up(int i){//上传标记 
     7         tree[i].sum = (tree[lson].sum + tree[rson].sum) % p;
     8         return ;
     9     }
    10     void build(int i, int l , int r){//建树 
    11         tree[i].lazy = 0, tree[i].len = r - l + 1;
    12         if(l == r) {    
    13             tree[i].sum = a[pre[l]] % p;
    14             return ;
    15         }
    16         int mid = l + r >> 1;
    17         build(lson, l, mid), build(rson, mid + 1, r);
    18         push_up(i);
    19         return ;
    20     }
    21     void pushdown(int i){//下传懒标记 
    22         if(tree[i].lazy){
    23             tree[lson].lazy = (tree[lson].lazy + tree[i].lazy) % p;
    24             tree[rson].lazy = (tree[rson].lazy + tree[i].lazy) % p;
    25             tree[lson].sum = (tree[lson].sum + tree[i].lazy * tree[lson].len) % p;
    26             tree[rson].sum = (tree[rson].sum + tree[i].lazy * tree[rson].len) % p;
    27             tree[i].lazy = 0;
    28         }
    29         return ;
    30     }
    31     void add(int i, int l, int r, int L, int R, int k){
    32     //lr表示遍历到的区间,LR表示查询到的区间 
    33         if(L <= l && r <= R) {
    34             tree[i].sum = (tree[i].sum + (k * tree[i].len) % p) % p;
    35             tree[i].lazy += k;
    36             return ;
    37         }
    38         //cout<<l<<" "<<R << " "<< r<< " "<<L<<"lkp"<<endl;
    39         if(l > R || r < L) return ;
    40         pushdown(i);
    41         int mid = (l + r) >> 1;
    42         if(L <= mid) add(lson, l, mid, L, R, k);
    43         if(R > mid) add(rson, mid + 1, r, L, R, k);
    44         push_up(i);
    45         return ;
    46     }
    47     int get(int i, int l, int r, int L, int R){
    48         int sum = 0;
    49         if(L <= l && r <= R) {
    50             return tree[i].sum % p;
    51         }
    52         if(l > R || r < L) return 0;
    53         pushdown(i);
    54         int mid = (l + r) >> 1;
    55         if(mid >= L) sum = (sum + get(lson, l, mid, L, R)) % p;
    56         if(mid < R) sum = (sum + get(rson, mid + 1, r, L, R)) % p;
    57         return sum % p;
    58     }
    View Code

    那么怎么更改信息呢

    (更改方式有点像倍增求LCA,珂以类比理解)

    如果两个元素不在同一条链上,

    将链顶深的元素一直向上跳,并在线段树中进行修改(提取)信息的操作

    如果两个元素在同一条链上,直接进行修改(提取)信息的操作

     1 void change(int x, int y, int k){
     2         while (top[x] != top[y]){//如果两个点的链顶不相同(感觉和LCA的处理有点类似 
     3             if(dep[top[x]] < dep[top[y]]) swap(x, y); 
     4             Seg::add(1, 1, n, dfn[top[x]], dfn[x], k);//先改变深度浅的 
     5             x = fath[top[x]];//向上跳到链顶的父亲 
     6         }
     7         if(dfn[x] > dfn[y]) swap(x, y);//最后肯定是在一条链上 
     8         Seg::add(1, 1, n, dfn[x], dfn[y], k);
     9         return ;
    10     }
    11     int ask(int x, int y){
    12         int ans = 0;
    13         while(top[x] != top[y]){//道理和change函数类似 
    14             if(dep[top[x]] < dep[top[y]]) swap(x, y);//先跳深度深度 
    15             ans = (ans + Seg::get(1, 1, n, dfn[top[x]], dfn[x])) % p;
    16             x = fath[top[x]];
    17         }
    18         if(dfn[x] > dfn[y]) swap(x, y);
    19         ans = (ans + Seg::get(1, 1, n, dfn[x], dfn[y])) % p;
    20         return ans % p;
    21     }

    例题的AC代码

    namespace相当于把一部分函数进行组合包装,珂以有效区分函数作用,并避免重变量名

    调用的时候和std类似,用***::即可

      1 /*
      2 Work by: Suzt_ilymics
      3 Knowledge: 树链剖分 
      4 Time: O(nlog^2n)
      5 */
      6 #include<iostream>
      7 #include<cstdio>
      8 #define int long long
      9 using namespace std;
     10 const int MAXN = 1e5+5;
     11 int n, m, r, p;
     12 int a[MAXN], pre[MAXN], siz[MAXN], son[MAXN], dep[MAXN], fath[MAXN], top[MAXN], dfn[MAXN];
     13 
     14 int read(){//因一个逗号写挂了的快读 
     15     /*int s=0,w=1;
     16        char ch=getchar();
     17       while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();}
     18        while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();
     19        return s*w;
     20    */
     21     int s = 0, w = 1;
     22     char ch = getchar();
     23     //while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();}
     24     while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();}
     25     while(ch >= '0' && ch <= '9') 
     26     s = s * 10 + ch - '0', ch = getchar();
     27     return s * w;
     28 }
     29 
     30 namespace Seg{//线段树板中板 
     31     #define lson i << 1
     32     #define rson i << 1 | 1
     33     struct Tree{//和,懒标记,长度 
     34         int sum, lazy, len;
     35     }tree[MAXN << 2];
     36     void push_up(int i){//上传标记 
     37         tree[i].sum = (tree[lson].sum + tree[rson].sum) % p;
     38         return ;
     39     }
     40     void build(int i, int l , int r){//建树 
     41         tree[i].lazy = 0, tree[i].len = r - l + 1;
     42         if(l == r) {    
     43             tree[i].sum = a[pre[l]] % p;
     44             return ;
     45         }
     46         int mid = l + r >> 1;
     47         build(lson, l, mid), build(rson, mid + 1, r);
     48         push_up(i);
     49         return ;
     50     }
     51     void pushdown(int i){//下传懒标记 
     52         if(tree[i].lazy){
     53             tree[lson].lazy = (tree[lson].lazy + tree[i].lazy) % p;
     54             tree[rson].lazy = (tree[rson].lazy + tree[i].lazy) % p;
     55             tree[lson].sum = (tree[lson].sum + tree[i].lazy * tree[lson].len) % p;
     56             tree[rson].sum = (tree[rson].sum + tree[i].lazy * tree[rson].len) % p;
     57             tree[i].lazy = 0;
     58         }
     59         return ;
     60     }
     61     void add(int i, int l, int r, int L, int R, int k){
     62     //lr表示遍历到的区间,LR表示查询到的区间 
     63         if(L <= l && r <= R) {
     64             tree[i].sum = (tree[i].sum + (k * tree[i].len) % p) % p;
     65             tree[i].lazy += k;
     66             return ;
     67         }
     68         //cout<<l<<" "<<R << " "<< r<< " "<<L<<"lkp"<<endl;
     69         if(l > R || r < L) return ;
     70         pushdown(i);
     71         int mid = (l + r) >> 1;
     72         if(L <= mid) add(lson, l, mid, L, R, k);
     73         if(R > mid) add(rson, mid + 1, r, L, R, k);
     74         push_up(i);
     75         return ;
     76     }
     77     int get(int i, int l, int r, int L, int R){
     78         int sum = 0;
     79         if(L <= l && r <= R) {
     80             return tree[i].sum % p;
     81         }
     82         if(l > R || r < L) return 0;
     83         pushdown(i);
     84         int mid = (l + r) >> 1;
     85         if(mid >= L) sum = (sum + get(lson, l, mid, L, R)) % p;
     86         if(mid < R) sum = (sum + get(rson, mid + 1, r, L, R)) % p;
     87         return sum % p;
     88     }
     89 }
     90 
     91 namespace Cut{
     92     int num_edge = 0, cnt = 0, head[MAXN << 1] = {0};
     93     struct edge{
     94         int nxt, to, from;
     95     }e[MAXN << 1];
     96     void add(int from, int to){ 
     97         e[++num_edge].to = to;
     98         e[num_edge].from = from;
     99         e[num_edge].nxt = head[from];
    100         head[from] = num_edge;
    101     }
    102     void dfs(int x, int fa){//
    103         siz[x] = 1, fath[x] = fa, dep[x] = dep[fa] + 1;//确定以x为根的子树的大小,父亲,深度 
    104         //cout<<cnt<<"lzx"<<x<<" "<<fa<<endl;
    105         for(int i = head[x]; i; i = e[i].nxt){//类似于lca初始化的遍历 
    106             int v = e[i].to;
    107             if(v == fa) continue;
    108             dfs(v, x);
    109             siz[x] += siz[v];//回溯的时候更新子树大小 
    110             if(siz[son[x]] < siz[v]) son[x] = v;//挑出重儿子 
    111         } 
    112     }
    113     //引入重链这个概念会使分的链最少,复杂度更优秀 
    114     void dfs2(int x, int tp){//分链,tp表示该链的顶端 
    115         top[x] = tp, dfn[x] = ++cnt, pre[cnt] = x;//确定x节点的链的顶端是tp,x的dfs序及反dfs序 
    116         if(son[x]) dfs2(son[x], tp);//为了使重链的dfn在一起,要先遍历重儿子 
    117         for(int i = head[x]; i; i = e[i].nxt){
    118             int v = e[i].to;
    119             if(v == fath[x] || son[x] == v) continue;//如果一个点的fa等于自己或者下一个点是它的重儿子就跳过
    120             //(如果是重儿子的话应该在以前就已经遍历了,所以还有防止在遍历一遍的作用 
    121             dfs2(v, v);//新开一条链 
    122         }
    123     }
    124     void change(int x, int y, int k){
    125         while (top[x] != top[y]){//如果两个点的链顶不相同(感觉和LCA的处理有点类似 
    126             if(dep[top[x]] < dep[top[y]]) swap(x, y); 
    127             Seg::add(1, 1, n, dfn[top[x]], dfn[x], k);//先改变深度浅的 
    128             x = fath[top[x]];//向上跳到链顶的父亲 
    129         }
    130         if(dfn[x] > dfn[y]) swap(x, y);//最后肯定是在一条链上 
    131         Seg::add(1, 1, n, dfn[x], dfn[y], k);
    132         return ;
    133     }
    134     int ask(int x, int y){
    135         int ans = 0;
    136         while(top[x] != top[y]){//道理和change函数类似 
    137             if(dep[top[x]] < dep[top[y]]) swap(x, y);//先跳深度深度 
    138             ans = (ans + Seg::get(1, 1, n, dfn[top[x]], dfn[x])) % p;
    139             x = fath[top[x]];
    140         }
    141         if(dfn[x] > dfn[y]) swap(x, y);
    142         ans = (ans + Seg::get(1, 1, n, dfn[x], dfn[y])) % p;
    143         return ans % p;
    144     }
    145 }
    146 
    147 signed main()
    148 {
    149     //输入 
    150     n = read(), m = read(), r = read(), p = read();
    151     for(int i = 1; i <= n; ++i) a[i] = read();
    152     for(int i = 1, u, v; i <= n - 1; ++i) {
    153         u = read(), v = read();
    154     //cout<<"bilibili";
    155         Cut::add(u, v), Cut::add(v, u);
    156     }
    157     //for(int i = 1; i <= Cut::num_edge; ++i)    printf("%d %dwzd
    ", Cut::e[i].from, Cut::e[i].to);
    158     //初始化 
    159     Cut::dfs(r,0), Cut::dfs2(r, r), Seg::build(1, 1, n);
    160     //操作 
    161     for(int i = 1, opt, x, y, k; i <= m; ++i){
    162         opt = read();
    163         if(opt == 1){
    164             x = read(), y = read(), k = read();
    165             Cut::change(x, y, k);
    166         }
    167         if(opt == 2){
    168             x = read(), y = read();
    169             printf("%lld
    ", Cut::ask(x, y));
    170         }
    171         if(opt == 3){
    172             x = read(), k = read();
    173             //cout<<dfn[x]<<" "<<siz[x]<<"zsf"<<endl;
    174             Seg::add(1, 1, n, dfn[x], dfn[x] + siz[x] - 1, k);
    175         }
    176         if(opt == 4){
    177             x = read(); 
    178             printf("%lld
    ", Seg::get(1, 1, n, dfn[x], dfn[x] + siz[x] - 1));
    179         }
    180         
    181     }
    182     return 0;
    183 }

     [ZJOI2008]树的统计

    一个维护最大值的例题

    自己犯得**错误:

    看清数据范围,提交时把检验用的cout删掉,max在push_up的时候只需要取它两个儿子的最大值

      1 /*
      2 Work by: Suzt_ilymics
      3 Knowledge: 树链剖分 
      4 Time: O(nlog^2n)
      5 */
      6 #include<iostream>
      7 #include<cstdio>
      8 #include<string>
      9 #include<cstdio>
     10 #define int long long
     11 using namespace std;
     12 const int inf = -1000000000;
     13 const int MAXN = 3e4+5;
     14 int n, m;
     15 string s;
     16 int a[MAXN], pre[MAXN], siz[MAXN], son[MAXN], dep[MAXN], fath[MAXN], top[MAXN], dfn[MAXN];
     17 
     18 int read(){
     19     int s = 0, w = 1;
     20     char ch = getchar();
     21     while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();}
     22     while(ch >= '0' && ch <= '9') 
     23     s = s * 10 + ch - '0', ch = getchar();
     24     return s * w;
     25 }
     26 
     27 namespace Seg{
     28     #define lson i << 1
     29     #define rson i << 1 | 1
     30     struct Tree{
     31         int sum, lazy, len, max;
     32     }tree[MAXN << 2];
     33     void push_up(int i){
     34         tree[i].sum = tree[lson].sum + tree[rson].sum;
     35         tree[i].max = max(tree[lson].max, tree[rson].max);
     36         return ;
     37     }
     38     void build(int i, int l , int r){
     39         tree[i].lazy = 0, tree[i].len = r - l + 1;
     40         if(l == r) {    
     41             tree[i].sum = a[pre[l]];
     42             tree[i].max = a[pre[l]];
     43             return ;
     44         }
     45         int mid = (l + r) >> 1;
     46         build(lson, l, mid), build(rson, mid + 1, r);
     47         push_up(i);
     48         return ;
     49     }
     50     void add(int i, int l, int r, int L, int R, int k){
     51         if(L <= l && r <= R) {
     52             tree[i].sum = k;
     53             tree[i].max = k;
     54             return ;
     55         }
     56         if(l > R || r < L) return ;
     57         int mid = (l + r) >> 1;
     58         if(L <= mid) add(lson, l, mid, L, R, k);
     59         if(R > mid) add(rson, mid + 1, r, L, R, k);
     60         push_up(i);
     61         return ;
     62     }
     63     int get_sum(int i, int l, int r, int L, int R){
     64         int sum = 0;
     65         if(L <= l && r <= R) {
     66             return tree[i].sum;
     67         }
     68         if(l > R || r < L) return 0;
     69         int mid = (l + r) >> 1;
     70         if(mid >= L) sum += get_sum(lson, l, mid, L, R);
     71         if(mid < R) sum += get_sum(rson, mid + 1, r, L, R);
     72         return sum;
     73     }
     74     int get_max(int i, int l, int r, int L, int R){
     75         int maxm = inf;
     76         if(L <= l && r <= R){
     77             return tree[i].max;
     78         }
     79         if(l > R || r < L) return inf;
     80         int mid = (l + r) >> 1;
     81         if(mid >= L) maxm = max (maxm, get_max(lson, l, mid, L, R));
     82         if(mid < R) maxm = max (maxm, get_max(rson, mid + 1, r, L, R));
     83         return maxm;
     84     }
     85 }
     86 
     87 namespace Cut{
     88     int num_edge = 0, cnt = 0, head[MAXN << 1] = {0};
     89     struct edge{
     90         int nxt, to, from;
     91     }e[MAXN << 1];
     92     void add(int from, int to){ 
     93         e[++num_edge].to = to;
     94         e[num_edge].from = from;
     95         e[num_edge].nxt = head[from];
     96         head[from] = num_edge;
     97     }
     98     void dfs(int x, int fa){//
     99         siz[x] = 1, fath[x] = fa, dep[x] = dep[fa] + 1;
    100         for(int i = head[x]; i; i = e[i].nxt){
    101             int v = e[i].to;
    102             if(v == fa) continue;
    103             dfs(v, x);
    104             siz[x] += siz[v];
    105             if(siz[son[x]] < siz[v]) son[x] = v;
    106         } 
    107     }
    108     void dfs2(int x, int tp){
    109         top[x] = tp, dfn[x] = ++cnt, pre[cnt] = x;
    110         if(son[x]) dfs2(son[x], tp);
    111         for(int i = head[x]; i; i = e[i].nxt){
    112             int v = e[i].to;
    113             if(v == fath[x] || son[x] == v) continue;
    114             dfs2(v, v);
    115         }
    116     }
    117     int ask_sum(int x, int y){
    118         int ans = 0;
    119         while(top[x] != top[y]){
    120             if(dep[top[x]] < dep[top[y]]) swap(x, y);
    121             ans += Seg::get_sum(1, 1, n, dfn[top[x]], dfn[x]);
    122             x = fath[top[x]];
    123         }
    124         if(dfn[x] > dfn[y]) swap(x, y);
    125         ans += Seg::get_sum(1, 1, n, dfn[x], dfn[y]);
    126         return ans;
    127     }
    128     int ask_max(int x, int y){
    129         int maxm = inf;
    130         while(top[x] != top[y]){
    131             if(dep[top[x]] < dep[top[y]]) swap(x, y);
    132             maxm = max (maxm, Seg::get_max(1, 1, n, dfn[top[x]], dfn[x]));
    133             x = fath[top[x]];
    134         }
    135         if(dfn[x] > dfn[y]) swap(x, y);
    136         maxm = max (maxm, Seg::get_max(1, 1, n, dfn[x], dfn[y]));
    137         return maxm;
    138     }
    139 }
    140 
    141 signed main()
    142 {
    143     n = read();
    144     for(int i = 1, u, v; i <= n - 1; ++i) {
    145         u = read(), v = read();
    146         Cut::add(u, v), Cut::add(v, u);
    147     }
    148     for(int i = 1; i <= n; ++i) a[i] = read();
    149 
    150     Cut::dfs(1,0), Cut::dfs2(1, 1), Seg::build(1, 1, n);
    151     
    152     m = read();
    153     for(int i = 1, x, y, k; i <= m; ++i){
    154         cin>>s;
    155         if(s[1] == 'M'){//Qmax
    156             x = read(), y = read();
    157             if(x > y) swap(x, y);
    158             printf("%lld
    ", Cut::ask_max(x, y));
    159         }
    160         if(s[1] == 'H'){//Change
    161             x = read(), k = read();
    162             Seg::add(1, 1, n, dfn[x], dfn[x], k);
    163         }
    164         if(s[1] == 'S'){//Qsum
    165             x = read(), y = read();
    166             if(x > y) swap(x, y);
    167             printf("%lld
    ", Cut::ask_sum(x, y));
    168         }
    169     }
    170     return 0;
    171 }
    AC代码
  • 相关阅读:
    macOS Sierra 如何打开任何来源
    centos 安装git服务器,配置使用证书登录并你用hook实现代码自动部署
    Linux下修改Mysql的用户(root)的密码
    mysql主从复制
    CentOS7下安装MySQL5.7安装与配置
    gulp安装和使用
    libiconv库的安装和使用
    Android 开发中常见的注意点
    扯一扯 C#委托和事件?策略模式?接口回调?
    Python 学习开篇
  • 原文地址:https://www.cnblogs.com/Silymtics/p/13868056.html
Copyright © 2011-2022 走看看