zoukankan      html  css  js  c++  java
  • 初识树链剖分

    首发于摸鱼世界&更好的阅读体验

    到现在也只会照着std打板子..

    虽然这样,树链剖分还是一个非常优雅的算法。


    前置芝士:(DFS),线段树

    树链剖分可以把树上的区间操作通过把树剖成一条条链,利用线段树数据结构进行维护,从而达到(O(nlogn))的优秀时间复杂度。

    比如这样的操作:

    在一棵树上,将(x)(y)路径上点的点权加上(w),并要求支持查询两个点(x,y)路径间的点权和。

    乍一看,两个操作都很简单。修改操作可以用树上差分(O(1))乱搞,静态查询可以用(LCA)完成。

    但是合起来就没有办法了:每次查询之前都需要(O(n))预处理,数据略大直接(T)飞。

    于是树剖出场了。


    区间修改&查询是线段树的强项,但是它只能对一段连续的区间进行查询。于是我们需要想办法让树上需要操作的路径变成一段连续的区间。

    引入一个概念:重儿子,也就是一个节点的儿子中(size)最大的。连接到重儿子的边即为重边

    重儿子组成的,就是重链

    比如在这棵树中,连续的红边组成的就是一条条重链。我们用(top[u])记录节点(u)所在重链的顶端。特别地,没有被重边连接的节点,(top[u]=u),即它们所在重链的顶端就是自身。注意到,当(u)是一条重链的顶端((top[u]=u))时,它的父节点一定在另一条重链上

    始终记住我们的目标:把在树上区间操作转化为在一段连续的区间进行操作。

    考虑如何用(DFS)给树上的每个节点在区间内找到一个合适的位置。我们发现,从根节点出发,优先走重边,这样的(dfs)序似乎有点特殊。

    例如上图,优先走重边的(dfs)序为:(124798356)。很显然,这样的(dfs)序满足同一条重链上的点(dfs)序连续。所以用线段树维护的,就是重链上的信息

    这样操作之后,我们可以做到的是:(O(logn))对一条重链上的信息区间修改,区间查询。

    对于两个节点(u,v),我们可以通过不断地跳重链,直到两个节点在同一条重链上。这个是很好实现的,因为只需要跳到(fa[top[u]]),就到了一条新的重链。

    代码实现仅树剖部分是不麻烦的。我们需要维护的信息有(dep)(节点深度),(fa)(父节点),(son)(重儿子),(sz)(子树节点数,用来判重儿子),这些可以用一次(dfs)完成。

    void dfs1(int u,int f,int d)//fa,dep,son,sz
    {
    	fa[u]=f;
    	dep[u]=d;
    	sz[u]=1;
    	for(int i=head[u];i;i=nxt[i])
    	{
    		int v=to[i];
    		if(v!=f)
    		{
    			dfs1(v,u,d+1);
    			sz[u]+=sz[v];
    			if(sz[v]>sz[son[u]])son[u]=v;
    		}
    
    	}
    }
    

    接下来,就需要把这棵树每个节点压到线段树维护的序列的一个位置了。就像上文说的一样,按照优先重边(dfs)序压入线段树即可。于是记录一个(id[i])表示原树中节点(i)对应的线段树中的下标。(rk[i])反过来记录线段树中下标为(i)的原数编号。

    由于预处理了父节点,所以(dfs2)传参只需要(u)(当前节点)和(t)(当前重链顶端节点)。在遍历儿子之前先(dfs2(son[u],t)),因为(u)(u)的重儿子在同一条重链上。接下来才遍历轻(非重)儿子(v),但是传参为(dfs2(v,v)),因为(v)就是新的一条重链的起点。

    void dfs2(int u,int t)//top,id,rk
    {
    	top[u]=t;
    	id[u]=++tot;
    	rk[tot]=u;
    	if(!son[u])return;
    	dfs2(son[u],t);
    	for(int i=head[u];i;i=nxt[i])
    	{
    		int v=to[i];
    		if(v!=fa[u]&&v!=son[u])
    			dfs2(v,v);
    	}
    }
    

    再回到最开始的问题:

    在一棵树上,将(x)(y)路径上点的点权加上(w),并要求支持查询两个点(x,y)路径间的点权和。

    答案就显得很明了了。

    如果是查询,先保证(dep[x]>dep[y]),然后就和(LCA)类似的,利用重链加速:每次把([top[x],x])这条重链的和累加到答案上,再使(x)跳到另一条重链上,即(x=fa[top[x]]),直到(x,y)在同一条重链上,再把两个点之间的信息统计累加一下即可。

    int getsum(int x,int y)
    {
    	int res=0;
    	while(top[x]!=top[y])
    	{
    		if(dep[top[x]]<dep[top[y]])swap(x,y);
    		sum=0;
    		asksum(1,id[top[x]],id[x]);
    		(res+=sum)%=mod;
    		x=fa[top[x]];
    	}
    	if(id[x]>id[y])swap(x,y);
    	sum=0;
    	asksum(1,id[x],id[y]);
    	(res+=sum)%=mod;
    	return res;
    }
    

    修改同理。

    于是我们发现,虽然我们采用了优先重边的(dfs)序,但它毕竟遍历的都是自己的儿子节点。所以...还可以支持子树操作。因为一棵子树在重边优先的(dfs)序中编号也是连续的。并且这个编号很容易算,因为我们维护了一个(sz)信息。所以树中(x)节点的子树对应的就是线段树维护的([id[x],id[x]+sz[x]-1])这个区间

    于是还是板子一般的线段树区间修改&查询。


    可以注意到线段树部分基本没讲,因为每个人写线段树的方法可能不太一样,蒟蒻我分享的只是树剖的思想。

    另外,为什么树剖每次操作是(O(logn))呢?利用线段树的子树操作自然是(O(logn)),剩下的就是那个像(LCA)一样的跳重链。

    证明:从任意节点向根节点跳重链,经过的重链和轻边(非重边)都是(log)级别的。

    考虑到每走一条轻边,子树大小至少翻倍,否则这就不是条轻边了。于是经过的轻边就最多为(log_2 n)条。而重链和轻边的交替出现的,所以数量也在这个级别。

    于是每次操作就只有(O(logn))的时间复杂度。

    模板题

    以下是代码

    #include<bits/stdc++.h>
    #define int long long
    #define ls (k<<1)
    #define rs (k<<1|1)
    using namespace std;
    const int N=1e5+10;
    struct node
    {
    	int l,r,w,f;
    }t[N<<2];
    int a[N];
    int n,m,r,mod;
    int sum;
    int head[N<<1],to[N<<1],nxt[N<<1],cnt;
    int sz[N],fa[N],dep[N],son[N];
    int top[N],id[N],rk[N],tot;
    inline int read()
    {
       int x=0,f=1;
       char ch=getchar();
       while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
       while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
       return x*f;
    }
    void add(int u,int v)
    {
    	cnt++;
    	to[cnt]=v;
    	nxt[cnt]=head[u];
    	head[u]=cnt;
    }
    void dfs1(int u,int f)
    {
    	fa[u]=f;
    	sz[u]=1;
    	dep[u]=dep[f]+1;
    	for(int i=head[u];i;i=nxt[i])
    	{
    		int v=to[i];
    		if(v==f)continue;
    		dfs1(v,u);
    		sz[u]+=sz[v];
    		if(sz[v]>sz[son[u]])son[u]=v;
    	}
    	return;
    }
    void dfs2(int u,int t)
    {
    	top[u]=t;
    	id[u]=++tot;
    	rk[tot]=u;
    	if(!son[u])return;
    	dfs2(son[u],t);
    	for(int i=head[u];i;i=nxt[i])
    	{
    		int v=to[i];
    		if(v!=fa[u]&&v!=son[u])dfs2(v,v);//新的重链 
    	}
    }
    void build(int k,int l,int r)
    {
    	t[k].l=l,t[k].r=r;
    	if(l==r)
    	{
    		t[k].w=a[rk[l]];
    		return;
    	}
    	int m=l+r>>1;
    	build(ls,l,m);
    	build(rs,m+1,r);
    	t[k].w=t[ls].w+t[rs].w;
    	return;
    }
    void down(int k)
    {
    	t[ls].w+=(t[ls].r-t[ls].l+1)*t[k].f;
    	t[rs].w+=(t[rs].r-t[rs].l+1)*t[k].f;
    	t[ls].f+=t[k].f;
    	t[rs].f+=t[k].f;
    	t[k].f=0;
    }
    void addsum(int k,int x,int y,int p)
    {
    	int l=t[k].l,r=t[k].r;
    	if(x<=l&&r<=y)
    	{
    		t[k].w+=(r-l+1)*p;
    		t[k].f+=p;
    		return;
    	}
    	down(k);
    	int m=l+r>>1;
    	if(x<=m)addsum(ls,x,y,p);
    	if(y>m)addsum(rs,x,y,p);
    	t[k].w=t[ls].w+t[rs].w;
    	return;
    }
    void asksum(int k,int x,int y)
    {
    	int l=t[k].l,r=t[k].r;
    	if(x<=l&&r<=y)
    	{
    		sum+=t[k].w;
    		return;
    	}
    	down(k);
    	int m=l+r>>1;
    	if(x<=m)asksum(ls,x,y);
    	if(y>m)asksum(rs,x,y);
    	t[k].w=t[ls].w+t[rs].w;
    	return;
    }
    //-----------------------------
    int getsum(int x,int y)
    {
    	int res=0;
    	while(top[x]!=top[y])
    	{
    		if(dep[top[x]]<dep[top[y]])swap(x,y);
    		sum=0;
    		asksum(1,id[top[x]],id[x]);
    		(res+=sum)%=mod;
    		x=fa[top[x]];
    	}
    	if(id[x]>id[y])swap(x,y);
    	sum=0;
    	asksum(1,id[x],id[y]);
    	(res+=sum)%=mod;
    	return res;
    }
    void update(int x,int y,int p)
    {
    	while(top[x]!=top[y])
    	{
    		if(dep[top[x]]<dep[top[y]])swap(x,y);
    		addsum(1,id[top[x]],id[x],p);
    		x=fa[top[x]];
    	}
    	if(id[x]>id[y])swap(x,y);
    	addsum(1,id[x],id[y],p);
    	return;
    }
    signed main()
    {
    	n=read(),m=read(),r=read(),mod=read();
    	for(int i=1;i<=n;i++)a[i]=read();
    	for(int i=1;i<n;i++)
    	{
    		int x=read(),y=read();
    		add(x,y),add(y,x);
    	}
    	dfs1(r,0);
    	dfs2(r,r);
    	build(1,1,n);
    	for(int i=1;i<=m;i++)
    	{
    		int x,y,z;
    		int opt=read();
    		if(opt==1)
    		{
    			x=read(),y=read(),z=read();
    			update(x,y,z);
    		}
    		if(opt==2)
    		{
    			x=read(),y=read();
    			printf("%lld
    ",getsum(x,y)%mod);
    		}
    		if(opt==3)
    		{
    			x=read(),z=read();
    			addsum(1,id[x],id[x]+sz[x]-1,z);
    		}
    		if(opt==4)
    		{
    			x=read();
    			sum=0;asksum(1,id[x],id[x]+sz[x]-1);
    			printf("%lld
    ",sum%mod);
    		}
    	}
    	return 0;
    }
    

    代码的确是长,也不算容易调,但是真正妙的是利用轻重链的思想进行的化树为链。

    感谢阅读。

  • 相关阅读:
    Oracle基础知识整理
    linux下yum安装redis以及使用
    mybatis 学习四 源码分析 mybatis如何执行的一条sql
    mybatis 学习三 mapper xml 配置信息
    mybatis 学习二 conf xml 配置信息
    mybatis 学习一 总体概述
    oracle sql 语句 示例
    jdbc 新认识
    eclipse tomcat 无法加载导入的web项目,There are no resources that can be added or removed from the server. .
    一些常用算法(持续更新)
  • 原文地址:https://www.cnblogs.com/moyujiang/p/13446439.html
Copyright © 2011-2022 走看看