zoukankan      html  css  js  c++  java
  • 树链剖分模板

    概念

    树剖就是将一棵树暴力拆成几条链,然后对于这样一个序列,我们就可以套上资瓷区间处理的一些东西qwq(比如说线段树,树状数组

    可以解决的问题:

    • 将树从$x$到$y$结点最短路径上所有节点的值都加上$z$
    • 求树从$x$到$y$结点最短路径上所有节点的值之和/最大值
    • 将以$x$为根节点的子树内所有节点值都加上$z$
    • 求以$x$为根节点的子树内所有节点值之和/最大值

    一些概念:

    • 重儿子:父亲节点的所有儿子中子树结点数目最多($size$最大)的结点;
    • 轻儿子:父亲节点中除了重儿子以外的儿子;
    • 重边:父亲结点和重儿子连成的边;
    • 轻边:父亲节点和轻儿子连成的边;
    • 重链:由多条重边连接而成的路径;
    • 轻链:由多条轻边连接而成的路径;

    实现

    一些定义:

    • $f(x)$表示节点$x$在树上的父亲
    • $dep(x)$表示节点$x$在树上的深度
    • $siz(x)$表示节点$x$的子树的节点的个数
    • $son(x)$表示节点$x$的重儿子
    • $top(x)$表示节点$s$所在重链的顶部节点(深度最小)
    • $id(x)$表示节点$x$在线段树中的编号
    • $rk(x)$表示线段树中标号为$x$的节点对应的树上节点的编号

    1、第一次DFS,对于一个点求出它所在的子树的大小、它的重儿子,顺便记录其父节点和深度。

     1 void dfs1(int u, int fa, int depth)  //当前节点、父节点、层次深度
     2 {
     3     //printf("u:%d fa:%d depth:%d
    ", u, fa, depth);
     4     f[u] = fa;
     5     deep[u] = depth;
     6     size[u] = 1;   //这个点本身的size
     7     for(int i = head[u];i;i = edges[i].next)
     8     {
     9         int v = edges[i].to;
    10         if(v == fa)  continue;
    11         dfs1(v, u, depth+1);
    12         size[u] += size[v];   //子节点的size已被处理,用它来更新父节点的size
    13         if(size[v] > size[son[u]])  son[u] = v;    //选取size最大的作为重儿子
    14     }
    15 }

    2、第二次DFS,连接重链,同时标记每个节点的DFS序。为了用数据结构来维护重链,我们在DFS时保证一条重链上的节点DFS序连续。一个节点的子树内DFS序也连续。

     1 void dfs2(int u, int t)  //当前节点、重链顶端
     2 {
     3     printf("u:%d t:%d
    ", u, t);
     4     top[u] = t;
     5     id[u] = ++cnt;   //标记dfs序
     6     rk[cnt] = u;     //序号cnt对应节点u
     7     if(!son[u])  return;   //没有儿子?
     8     dfs2(son[u], t);  //我们选择优先进入重儿子来保证一条重链上各个节点dfs序连续
     9 
    10     for(int i = head[u];i;i = edges[i].next)
    11     {
    12         int v = edges[i].to;
    13         if(v != son[u] && v != f[u])  dfs2(v, v);  //这个点位于轻链顶端,那么它的top必然为它本身
    14     }
    15 }

    3、两遍DFS就是树链剖分的主要处理,通过dfs我们已经保证一条重链上各个节点的dfs序连续,那么可以想到,我们可以通过数据结构来维护(以线段树为例)来维护一条重链的信息。

    维护和

     1 ll querysum(int x, int y)
     2 {
     3     int fx = top[x], fy = top[y];
     4     ll ans = 0;
     5     while(fx != fy)   //当两者不在同一条重链上
     6     {
     7         if(deep[fx] >= deep[fy])
     8         {
     9             ans += st.query2(1, 1, n, 0, id[fx], id[x]);   //线段树区间求和,计算这条重链的贡献
    10             x = f[fx]; fx = top[x];
    11         }
    12         else
    13         {
    14             ans += st.query2(1, 1, n, 0, id[fy], id[y]);
    15             y = f[fy]; fy = top[y];
    16         }
    17     }
    18 
    19     //循环结束,两点位于同一重链上,但两者不一定为同一点,所以还要加上这两点之间的贡献
    20     if(id[x] <= id[y])
    21     {
    22         ans += st.query2(1, 1, n, 0, id[x], id[y]);
    23     }
    24     else
    25     {
    26         ans += st.query2(1, 1, n, 0, id[y], id[x]);
    27     }
    28     return ans;
    29 }

    维护最大值

     1 ll querymax(int x, int y)
     2 {
     3     int fx = top[x], fy = top[y];
     4     ll ans = -INF;
     5     while(fx != fy)   //当两者不在同一条重链上
     6     {
     7         if(deep[fx] >= deep[fy])
     8         {
     9             ans = max(ans, st.query1(1, 1, n, 0, id[fx], id[x]));   //线段树区间求和,计算这条重链的贡献
    10             x = f[fx]; fx = top[x];
    11         }
    12         else
    13         {
    14             ans = max(ans, st.query1(1, 1, n, 0, id[fy], id[y]));
    15             y = f[fy]; fy = top[y];
    16         }
    17     }
    18     //循环结束,两点位于同一重链上,但两者不一定为同一点,所以还要加上这两点之间的贡献
    19     if(id[x] <= id[y])  ans = max(ans, st.query1(1, 1, n, 0, id[x], id[y]));
    20     else ans = max(ans, st.query1(1, 1, n, 0, id[y], id[x]));
    21     return ans;
    22 }

    时间复杂度

    对于每次询问,最多经过$O(log n)$条重链,每条重链上线段树的复杂度为$O(log n)$,此时总的时间复杂度为$O(nlogn+q{log}^2n)$。实际上重链个数很难达到$O(log n)$(可以用完全二叉树卡满),所以树剖在一般情况下常数较小。

    完整的代码

      1 #include<bits/stdc++.h>
      2 using namespace std;
      3 
      4 typedef long long ll;
      5 #define lc o <<1
      6 #define rc o <<1 | 1
      7 const int INF = 0x3f3f3f3f;
      8 const int maxn = 100000 + 10;
      9 struct Edge
     10 {
     11     int to, next;
     12 }edges[2*maxn];
     13 int head[maxn];
     14 int cur, f[maxn], deep[maxn], size[maxn], son[maxn], rk[maxn], id[maxn], top[maxn], cnt;
     15 int n, root, qcnt, w[maxn];
     16 
     17 inline void addedge(int u, int v)
     18 {
     19     ++cur;
     20     edges[cur].next = head[u];
     21     head[u] = cur;
     22     edges[cur].to = v;
     23 }
     24 
     25 struct SegTree{
     26     ll sum[maxn << 2], maxv[maxn << 2], addv[maxn << 2];
     27     void build(int o, int l, int r)
     28     {
     29         if(l == r)
     30         {
     31             sum[o] = maxv[o] = w[rk[l]];
     32         }
     33         else
     34         {
     35             int mid = (l + r) >> 1;
     36             build(lc, l, mid);
     37             build(rc, mid+1, r);
     38             sum[o] = sum[lc] + sum[rc];
     39             maxv[o] = max(maxv[lc], maxv[rc]);
     40         }
     41     }
     42 
     43     void maintain(int o, int l, int r)
     44     {
     45         if(l == r)  //如果是叶子结点
     46         {
     47             maxv[o] = w[rk[l]];
     48             sum[o] = w[rk[l]];
     49         }
     50         else     //如果是非叶子结点
     51         {
     52             maxv[o] = max(maxv[lc], maxv[rc]);
     53             sum[o] = sum[lc] + sum[rc];
     54         }
     55         maxv[o] += addv[o];     //考虑add操作
     56         sum[o] += addv[o] * (r-l+1);
     57     }
     58     //区间修改,[cl,cr] += v;
     59     void update(int o, int l, int r, int cl, int cr, int v)  //
     60     {
     61         //printf("o:%d  l:%d  r:%d
    ", o, l, r);
     62         if(cl <= l && r <= cr)  addv[o] += v;
     63         else
     64         {
     65             int m = l + (r-l) /2;
     66             if(cl <= m)  update(lc, l, m, cl, cr, v);
     67             if(cr > m)  update(rc, m+1, r, cl, cr, v);
     68         }
     69         maintain(o, l, r);
     70     }
     71 
     72     //区间查询1,max{ql,qr}
     73     ll query1(int o, int l,int r, int add, int ql, int qr)
     74     {
     75         //prllf("o:%d l:%d r:%d
    ", o, l, r);
     76         if(ql <= l && r <= qr)  return maxv[o] + add;
     77         else
     78         {
     79             int m = l + (r - l) / 2;
     80             ll ans = -INF;
     81             add += addv[o];
     82             if(ql <= m)  ans = max(ans, query1(lc, l, m, add, ql, qr));
     83             if(qr > m)  ans = max(ans, query1(rc, m+1, r, add, ql, qr));
     84             return ans;
     85         }
     86     }
     87 
     88     //区间查询2,sum{ql,qr}
     89     ll query2(int o, int l,int r, int add, int ql, int qr)
     90     {
     91         //prllf("o:%d l:%d r:%d ql:%d qr:%d
    ", o, l, r, ql, qr);
     92         if(ql <= l && r <= qr)  return sum[o] + add * (r-l+1);
     93         else
     94         {
     95             int m = l + (r - l) / 2;
     96             ll ans = 0;
     97             add += addv[o];
     98             if(ql <= m)  ans += query2(lc, l, m, add, ql, qr);
     99             if(qr > m)  ans += query2(rc, m+1, r, add, ql, qr);
    100             return ans;
    101         }
    102     }
    103 }st;
    104 
    105 void dfs1(int u, int fa, int depth)  //当前节点、父节点、层次深度
    106 {
    107     //printf("u:%d fa:%d depth:%d
    ", u, fa, depth);
    108     f[u] = fa;
    109     deep[u] = depth;
    110     size[u] = 1;   //这个点本身的size
    111     for(int i = head[u];i;i = edges[i].next)
    112     {
    113         int v = edges[i].to;
    114         if(v == fa)  continue;
    115         dfs1(v, u, depth+1);
    116         size[u] += size[v];   //子节点的size已被处理,用它来更新父节点的size
    117         if(size[v] > size[son[u]])  son[u] = v;    //选取size最大的作为重儿子
    118     }
    119 }
    120 
    121 void dfs2(int u, int t)  //当前节点、重链顶端
    122 {
    123     printf("u:%d t:%d
    ", u, t);
    124     top[u] = t;
    125     id[u] = ++cnt;   //标记dfs序
    126     rk[cnt] = u;     //序号cnt对应节点u
    127     if(!son[u])  return;   //没有儿子?
    128     dfs2(son[u], t);  //我们选择优先进入重儿子来保证一条重链上各个节点dfs序连续
    129 
    130     for(int i = head[u];i;i = edges[i].next)
    131     {
    132         int v = edges[i].to;
    133         if(v != son[u] && v != f[u])  dfs2(v, v);  //这个点位于轻链顶端,那么它的top必然为它本身
    134     }
    135 }
    136 
    137 ll querymax(int x, int y)
    138 {
    139     int fx = top[x], fy = top[y];
    140     ll ans = -INF;
    141     while(fx != fy)   //当两者不在同一条重链上
    142     {
    143         if(deep[fx] >= deep[fy])
    144         {
    145             ans = max(ans, st.query1(1, 1, n, 0, id[fx], id[x]));   //线段树区间求和,计算这条重链的贡献
    146             x = f[fx]; fx = top[x];
    147         }
    148         else
    149         {
    150             ans = max(ans, st.query1(1, 1, n, 0, id[fy], id[y]));
    151             y = f[fy]; fy = top[y];
    152         }
    153     }
    154     //循环结束,两点位于同一重链上,但两者不一定为同一点,所以还要加上这两点之间的贡献
    155     if(id[x] <= id[y])  ans = max(ans, st.query1(1, 1, n, 0, id[x], id[y]));
    156     else ans = max(ans, st.query1(1, 1, n, 0, id[y], id[x]));
    157     return ans;
    158 }
    159 
    160 /*修改和查询的原理是一致的,以查询操作为例,其实就是个LCA,不过这里要使用top数组加速,因为top可以直接跳到该重链的起始顶点*/
    161 /*注意,每次循环只能跳一次,并且让结点深的那个跳到top的位置,避免两者一起跳而插肩而过*/
    162 ll querysum(int x, int y)
    163 {
    164     int fx = top[x], fy = top[y];
    165     ll ans = 0;
    166     while(fx != fy)   //当两者不在同一条重链上
    167     {
    168         if(deep[fx] >= deep[fy])
    169         {
    170             ans += st.query2(1, 1, n, 0, id[fx], id[x]);   //线段树区间求和,计算这条重链的贡献
    171             x = f[fx]; fx = top[x];
    172         }
    173         else
    174         {
    175             ans += st.query2(1, 1, n, 0, id[fy], id[y]);
    176             y = f[fy]; fy = top[y];
    177         }
    178     }
    179 
    180     //循环结束,两点位于同一重链上,但两者不一定为同一点,所以还要加上这两点之间的贡献
    181     if(id[x] <= id[y])
    182     {
    183         ans += st.query2(1, 1, n, 0, id[x], id[y]);
    184     }
    185     else
    186     {
    187         ans += st.query2(1, 1, n, 0, id[y], id[x]);
    188     }
    189     return ans;
    190 }
    191 
    192 void update_add(int x, int y, int add)
    193 {
    194     int fx = top[x], fy = top[y];
    195     while(fx != fy)   //当两者不在同一条重链上
    196     {
    197         if(deep[fx] >= deep[fy])
    198         {
    199             st.update(1, 1, n, id[fx], id[x], add);
    200             x = f[fx]; fx = top[x];
    201         }
    202         else
    203         {
    204             st.update(1, 1, n, id[fy], id[y], add);
    205             y = f[fy]; fy = top[y];
    206         }
    207     }
    208     //循环结束,两点位于同一重链上,但两者不一定为同一点,所以还要加上这两点之间的贡献
    209     if(id[x] <= id[y])  st.update(1, 1, n, id[x], id[y], add);
    210     else  st.update(1, 1, n, id[y], id[x], add);
    211 }
    212 
    213 int main()
    214 {
    215     scanf("%d%d%d", &n, &root, &qcnt);
    216     for(int i = 1;i <= n;i++)  scanf("%d", &w[i]);
    217     for(int i = 1;i < n;i++)
    218     {
    219         int u, v;
    220         scanf("%d%d", &u, &v);
    221         addedge(u, v);
    222         addedge(v, u);
    223     }
    224     dfs1(root, -1, 1);
    225     dfs2(root, root);
    226 
    227     for(int i = 1;i <= n;i++)  printf("%d  ", id[i]);
    228     printf("
    ");
    229     for(int i = 1;i <= n;i++)  printf("%d  ", rk[i]);
    230     printf("
    ");
    231 
    232     st.build(1, 1, n);
    233 
    234     while(qcnt--)
    235     {
    236         int op;
    237         scanf("%d", &op);
    238         if(op == 1)
    239         {
    240             int u, v, add;
    241             scanf("%d%d%d", &u, &v, &add);
    242             update_add(u, v,  add);
    243         }
    244         else if(op == 2)
    245         {
    246             int u, v;
    247             scanf("%d%d", &u, &v);
    248             printf("%d
    ", querymax(u, v));
    249         }
    250         else if(op == 3)
    251         {
    252             int u, v;
    253             scanf("%d%d", &u, &v);
    254             printf("%d
    ", querysum(u, v));
    255         }
    256         else if(op == 4)
    257         {
    258             int u, add;
    259             scanf("%d%d", &u, &add);
    260             st.update(1, 1, n, id[u], id[u]+size[u]-1, add);
    261         }
    262         else if(op == 5)
    263         {
    264             int u;
    265             scanf("%d", &u);
    266             printf("%d
    ",st.query1(1, 1, n, 0, id[u], id[u]+size[u]-1));
    267         }
    268         else
    269         {
    270             int u;
    271             scanf("%d", &u);
    272             printf("%d
    ",st.query2(1, 1, n, 0, id[u], id[u]+size[u]-1));
    273         }
    274     }
    275     return 0;
    276 }
    View Code

    参考链接:

    1. https://oi-wiki.org/graph/heavy-light-decomposition/

    2. https://zhuanlan.zhihu.com/p/41082337

    3. https://www.luogu.org/problemnew/solution/P3384

  • 相关阅读:
    python全栈开发从入门到放弃之socket并发编程之协程
    python全栈开发从入门到放弃之socket并发编程多线程GIL
    python全栈开发从入门到放弃之socket并发编程多线程
    python全栈开发从入门到放弃之socket并发编程多进程
    python全栈开发从入门到放弃之socket网络编程基础
    python全栈开发从入门到放弃之异常处理
    python全栈开发从入门到放弃之面向对象反射
    python全栈开发从入门到放弃之面向对象的三大特性
    转:经典ACM算法
    反射在Java Swing中的应用
  • 原文地址:https://www.cnblogs.com/lfri/p/11169231.html
Copyright © 2011-2022 走看看