zoukankan      html  css  js  c++  java
  • 树链剖分-重链剖分

    参考资料:

    https://www.cnblogs.com/ivanovcraft/p/9019090.html
    https://www.cnblogs.com/hanruyun/p/9577500.html


    前置知识:

    树的性质(深度,子树大小,距离,祖先等)
    线段树,树状数组


    树链剖分

    树链剖分可以说是个数据结构,但是更是个存储、操作树上信息的方法。树链剖分的主要思想是把一棵树拆成若干条链,并建立数据结构进行存储、操作,根据拆的方法不同,分为重链剖分和长链剖分。重链剖分用得较多,蒟蒻也更喜欢重链剖分,这篇就讲重链剖分。


    重链剖分

    例题

    为了让萌新对重链剖分有所了解,蒟蒻拿出了例题(几乎包括了重链剖分的所有操作):

    【模板】轻重链剖分

    有一棵 (n) 个节点的树,每个节点有权值,初始时为 (a{n})。有 (m) 个如下操作:
    操作 (1)(1 x y z) 表示将树从 (x)(y) 结点最短路径上所有节点的值都加上 (z)
    操作 (2)(2 x y) 表示求树从 (x)(y) 结点最短路径上所有节点的值之和。
    操作 (3)(3 x z) 表示将以 (x) 为根节点的子树内所有节点值都加上 (z)
    操作 (4)(4 x) 表示求以 (x) 为根节点的子树内所有节点值之和。

    数据范围:(1le n,mle 10^5)


    重点

    看了例题后,先看重链剖分关键概念:

    重儿子:父亲节点的子节点中子树大小最大的。

    轻儿子:父亲节点的子节点中除了重儿子的节点。

    重边:父亲节点与重儿子连成的边。

    轻边:父亲节点与轻儿子连成的边(即除了重边以外的边)。

    重链:由重边构成的极长路径。

    轻链:由轻边构成的极长路径。

    链头:重链深度最小的节点,是个轻儿子

    特殊的,根节点一般被看作轻儿子。

    如下图:

    黑色节点是轻儿子,橙色节点是重儿子。黑色边是轻边,橙色边是重边。底下有橙线的节点单独为一条重链。

    图中一共有 (5) 条重链:

    1. (1 o 3 o 8 o 9)
    2. (2 o 5)
    3. (4)
    4. (6)
    5. (7)

    每个轻儿子下面都有一条重链,所以重链数 (=) 轻儿子数

    这里有一个关键性的性质,使得树链剖分有实际意义:

    任何一个节点到根节点的路径中,包含不超过 (log n) 条重链。

    简证:重链与重链间用轻边连接,所以一个节点到根节点的路径中的重链数 (=) 轻边数 (+1)。轻边连的子树大小较小,所以必然 (< frac 12) 父亲节点子树大小。所以一个节点到根节点的路径中最多有 (log_2n) 条轻边,所以一个节点到根节点的路径中,包含不超过 (log n) 条重链。

    如果重节点优先( exttt{Dfs}) 遍历整棵树,为每个节点标上序号,如下:

    那么同个重链、同个子树的节点序号就是连续的了。

    所以可以维护棵线段树,把修改查询子树转换为区间修改查询,把修改查询最短路径转换为几个区间修改查询(上文提到,任何一个节点到根节点的路径中,包含不超过 (log n) 条重链。所以可以通过修改 (log n) 个区间,达到 (Theta(log n)) 的修改查询最短路径的目的)。

    然后树链剖分的思想就到这里了,如果你理解了,就可以看操作实现了。


    操作实现

    重链剖分的第一步都是两个 ( exttt{Dfs})

    第一个求出每个节点的深度、子树大小、父亲节点、重儿子

    void Dfs1(int x){
    	sz[x]=1,dep[x]=dep[fa[x]]+1;
    	for(int to:e[x])if(to!=fa[x]){
    		fa[to]=x,Dfs1(to),sz[x]+=sz[to];
    		if(sz[to]>sz[son[x]]) son[x]=to;
    	}
    }
    

    第二个找每条重链,并且为节点标上序号

    void Dfs2(int x,int an){
    	tp[x]=an,dfn[x]=++ind,rk[ind]=x;
    	if(son[x]) Dfs2(son[x],an);
    	for(int to:e[x]) if(to!=fa[x]&&to!=son[x]) Dfs2(to,to);
    }
    

    ( exttt{Dfs}) 完后,可以造数据结构了。就拿例题来说,需要造一棵线段树。线段树的任务有维护区间和,区间修改。可以用 ( exttt{lazytag+pushdown})

    lng v[(N<<2)+7],ma[(N<<2)+7];
    #define mid ((l+r)>>1)
    void pushdown(int k,int l,int r){
    	if(!ma[k]) return;
    	(v[k<<1]+=ma[k]*(mid-l+1))%=mod;
    	(v[k<<1|1]+=ma[k]*(r-mid))%=mod;
    	(ma[k<<1]+=ma[k])%=mod;
    	(ma[k<<1|1]+=ma[k])%=mod;
    	ma[k]=0;
    }
    void build(int k=1,int l=1,int r=n){
    	if(l==r){v[k]=a[rk[l]];return;}
    	build(k<<1,l,mid),build(k<<1|1,mid+1,r),v[k]=v[k<<1]+v[k<<1|1];
    }
    void fixson(int x,int y,lng z,int k=1,int l=1,int r=n){ // 修改子树(序号区间)权值
    	if(x<=l&&r<=y){(ma[k]+=z)%=mod,v[k]+=z*(r-l+1);return;}
    	pushdown(k,l,r);
    	if(mid>=x)fixson(x,y,z,k<<1,l,mid);
    	if(mid<y)fixson(x,y,z,k<<1|1,mid+1,r);
    	v[k]=v[k<<1]+v[k<<1|1];
    }
    lng sumson(int x,int y,int k=1,int l=1,int r=n){ // 求子树(序号区间)权值和
    	if(x<=l&&r<=y) return v[k];
    	lng res=0;
    	pushdown(k,l,r);
    	if(mid>=x) res+=sumson(x,y,k<<1,l,mid);
    	if(mid<y)res+=sumson(x,y,k<<1|1,mid+1,r);
    	return res%mod;
    }
    

    至于修改查询最短路径,依赖于序号区间修改查询。依次遍历最短路径上的每条重链,整个操作过程如同求 ( exttt{LCA})

    void fixdis(int x,int y,int z){ //修改最短路径权值
    	for(;tp[x]!=tp[y];fixson(dfn[tp[x]],dfn[x],z),x=fa[tp[x]])
    		if(dep[tp[x]]<dep[tp[y]]) swap(x,y);
    	fixson(min(dfn[x],dfn[y]),max(dfn[x],dfn[y]),z);
    }
    lng sumdis(int x,int y){ //求最短路径权值和
    	lng res=0;
    	for(;tp[x]!=tp[y];(res+=sumson(dfn[tp[x]],dfn[x]))%=mod,x=fa[tp[x]])
    		if(dep[tp[x]]<dep[tp[y]]) swap(x,y);
    	return (res+sumson(min(dfn[x],dfn[y]),max(dfn[x],dfn[y])))%mod;
    }
    

    然后例题就迎刃而解了。


    代码

    #include <bits/stdc++.h>
    using namespace std;
    
    //Start
    #define lng long long
    #define db double
    #define mk make_pair
    #define pb push_back
    #define fi first
    #define se second
    #define rz resize
    const int inf=0x3f3f3f3f;
    const lng INF=0x3f3f3f3f3f3f3f3f;
    
    //Data
    const int N=1e5;
    int n,m,rt; lng mod,a[N+7]; vector<int> e[N+7];
    
    //Treesplit
    int ind,fa[N+7],sz[N+7],dep[N+7],son[N+7],tp[N+7],dfn[N+7],rk[N+7];
    void Dfs1(int x){
    	sz[x]=1,dep[x]=dep[fa[x]]+1;
    	for(int to:e[x])if(to!=fa[x]){
    		fa[to]=x,Dfs1(to),sz[x]+=sz[to];
    		if(sz[to]>sz[son[x]]) son[x]=to;
    	}
    }
    void Dfs2(int x,int an){
    	tp[x]=an,dfn[x]=++ind,rk[ind]=x;
    	if(son[x]) Dfs2(son[x],an);
    	for(int to:e[x]) if(to!=fa[x]&&to!=son[x]) Dfs2(to,to);
    }
    lng v[(N<<2)+7],ma[(N<<2)+7];
    #define mid ((l+r)>>1)
    void pushdown(int k,int l,int r){
    	if(!ma[k]) return;
    	(v[k<<1]+=ma[k]*(mid-l+1))%=mod;
    	(v[k<<1|1]+=ma[k]*(r-mid))%=mod;
    	(ma[k<<1]+=ma[k])%=mod;
    	(ma[k<<1|1]+=ma[k])%=mod;
    	ma[k]=0;
    }
    void build(int k=1,int l=1,int r=n){
    	if(l==r){v[k]=a[rk[l]];return;}
    	build(k<<1,l,mid),build(k<<1|1,mid+1,r),v[k]=v[k<<1]+v[k<<1|1];
    }
    void fixson(int x,int y,lng z,int k=1,int l=1,int r=n){
    	if(x<=l&&r<=y){(ma[k]+=z)%=mod,v[k]+=z*(r-l+1);return;}
    	pushdown(k,l,r);
    	if(mid>=x)fixson(x,y,z,k<<1,l,mid);
    	if(mid<y)fixson(x,y,z,k<<1|1,mid+1,r);
    	v[k]=v[k<<1]+v[k<<1|1];
    }
    lng sumson(int x,int y,int k=1,int l=1,int r=n){
    	if(x<=l&&r<=y) return v[k];
    	lng res=0;
    	pushdown(k,l,r);
    	if(mid>=x) res+=sumson(x,y,k<<1,l,mid);
    	if(mid<y)res+=sumson(x,y,k<<1|1,mid+1,r);
    	return res%mod;
    }
    void fixdis(int x,int y,int z){
    	for(;tp[x]!=tp[y];fixson(dfn[tp[x]],dfn[x],z),x=fa[tp[x]])
    		if(dep[tp[x]]<dep[tp[y]]) swap(x,y);
    	fixson(min(dfn[x],dfn[y]),max(dfn[x],dfn[y]),z);
    }
    lng sumdis(int x,int y){
    	lng res=0;
    	for(;tp[x]!=tp[y];(res+=sumson(dfn[tp[x]],dfn[x]))%=mod,x=fa[tp[x]])
    		if(dep[tp[x]]<dep[tp[y]]) swap(x,y);
    	return (res+sumson(min(dfn[x],dfn[y]),max(dfn[x],dfn[y])))%mod;
    }
    
    //Main
    int main(){
    	scanf("%d%d%d%lld",&n,&m,&rt,&mod);
    	for(int i=1;i<=n;i++) scanf("%lld",&a[i]);
    	for(int i=1,x,y;i<=n-1;i++) scanf("%d%d",&x,&y),e[x].pb(y),e[y].pb(x);
    	Dfs1(rt),Dfs2(rt,rt),build();
    	for(int i=1;i<=m;i++){
    		int o,x,y; lng z; scanf("%d",&o);
    		if(o==1) scanf("%d%d%lld",&x,&y,&z),fixdis(x,y,z);
    		else if(o==2) scanf("%d%d",&x,&y),printf("%lld
    ",sumdis(x,y));
    		else if(o==3) scanf("%d%lld",&x,&z),fixson(dfn[x],dfn[x]+sz[x]-1,z);
    		else if(o==4) scanf("%d",&x),printf("%lld
    ",sumson(dfn[x],dfn[x]+sz[x]-1));
    	}
    	return 0;
    }
    

    特例

    有些题目要维护的是边的权值(例题维护的是点的权值),比如[USACO11DEC]Grass Planting G,方法是用点存储它与父亲节点相连的边的权值,具体看代码:

    #include <bits/stdc++.h>
    using namespace std;
    
    //Start
    #define lng long long
    #define db double
    #define mk make_pair
    #define pb push_back
    #define fi first
    #define se second
    #define rz resize
    const int inf=0x3f3f3f3f;
    const lng INF=0x3f3f3f3f3f3f3f3f;
    
    //Data
    const int N=1e5;
    int n,m;
    vector<int> e[N+7];
    
    //Treesplit
    int ind,fa[N+7],son[N+7],sz[N+7],dep[N+7],dfn[N+7],rk[N+7],tp[N+7];
    void Dfs1(int x){
    	sz[x]=1,dep[x]=dep[fa[x]]+1;
    	for(int to:e[x])if(to!=fa[x]){
    		fa[to]=x,Dfs1(to),sz[x]+=sz[to];
    		if(sz[to]>sz[son[x]]) son[x]=to;
    	}
    }
    void Dfs2(int x,int f){
    	tp[x]=f,dfn[x]=++ind,rk[ind]=x;
    	if(son[x]) Dfs2(son[x],f);
    	for(int to:e[x])if(to!=fa[x]&&to!=son[x]) Dfs2(to,to);
    }
    int v[(N<<2)+7],ma[(N<<2)+7];
    #define mid ((l+r)>>1)
    void pushdown(int k,int l,int r){
    	if(!ma[k]) return;
    	v[k<<1]+=ma[k],v[k<<1|1]+=ma[k];
    	ma[k<<1]+=ma[k],ma[k<<1|1]+=ma[k],ma[k]=0;
    }
    void fixson(int x,int y,int z,int k=1,int l=1,int r=n){
    	if(x<=l&&r<=y){v[k]+=z,ma[k]+=z;return;}
    	if(l==r) return; // 因为可能出现无边情况
    	pushdown(k,l,r);
    	if(mid>=x) fixson(x,y,z,k<<1,l,mid);
    	if(mid<y) fixson(x,y,z,k<<1|1,mid+1,r);
    	v[k]=max(v[k<<1],v[k<<1|1]);
    }
    int sumson(int x,int y,int k=1,int l=1,int r=n){
    	if(x<=l&&r<=y) return v[k];
    	if(l==r) return 0; // 因为可能出现无边情况
    	int res=0;
    	pushdown(k,l,r);
    	if(mid>=x) res=max(res,sumson(x,y,k<<1,l,mid));
    	if(mid<y) res=max(res,sumson(x,y,k<<1|1,mid+1,r));
    	return res;
    }
    void fixdis(int x,int y,int z){
    	for(;tp[x]!=tp[y];fixson(dfn[tp[x]],dfn[x],z),x=fa[tp[x]])
    		if(dep[tp[x]]<dep[tp[y]]) swap(x,y);
    	fixson(min(dfn[x],dfn[y])+1,max(dfn[x],dfn[y]),z); //+1 表示最短路径的最近公共祖先和其父亲连的边不改
    }
    int sumdis(int x,int y){
    	int res=0;
    	for(;tp[x]!=tp[y];res+=sumson(dfn[tp[x]],dfn[x]),x=fa[tp[x]])
    		if(dep[tp[x]]<dep[tp[y]]) swap(x,y);
    	return res+sumson(min(dfn[x],dfn[y])+1,max(dfn[x],dfn[y])); //+1 表示最短路径的最近公共祖先和其父亲连的边不算
    }
    
    //Main
    int main(){
    	scanf("%d%d",&n,&m);
    	for(int i=1,x,y;i<n;i++) scanf("%d%d",&x,&y),e[x].pb(y),e[y].pb(x);
    	Dfs1(1),Dfs2(1,1);
    	for(int i=1;i<=m;i++){
    		vector<char> s(5); int x,y;
    		scanf("%s%d%d",&s[1],&x,&y);
    		if(s[1]=='P') fixdis(x,y,1);
    		else if(s[1]=='Q') printf("%d
    ",sumdis(x,y));
    	}
    	return 0;
    }
    

    然后就这样讲完了,非常简单的东西。例题就不给了,到洛谷上随便撸几道即可。


    祝大家学习愉快!

  • 相关阅读:
    linux系统——机制与策略(三)
    linux系统——机制与策略(二)
    Linux系统——机制策略(一)
    RTSP会话基本流程
    linux编程学习
    编码风格——linux内核开发的coding style
    编程风格——整洁代码的4个提示
    编程风格——五种应该避免的代码注释
    十条不错的编程观点
    代码优化概要
  • 原文地址:https://www.cnblogs.com/George1123/p/12674071.html
Copyright © 2011-2022 走看看