zoukankan      html  css  js  c++  java
  • 初识树链剖分

    首发于摸鱼世界&更好的阅读体验

    到现在也只会照着std打板子..

    虽然这样,树链剖分还是一个非常优雅的算法。


    前置芝士:(DFS),线段树

    树链剖分可以把树上的区间操作通过把树剖成一条条链,利用线段树数据结构进行维护,从而达到(O(nlogn))的优秀时间复杂度。

    比如这样的操作:

    在一棵树上,将(x)(y)路径上点的点权加上(w),并要求支持查询两个点(x,y)路径间的点权和。

    乍一看,两个操作都很简单。修改操作可以用树上差分(O(1))乱搞,静态查询可以用(LCA)完成。

    但是合起来就没有办法了:每次查询之前都需要(O(n))预处理,数据略大直接(T)飞。

    于是树剖出场了。


    区间修改&查询是线段树的强项,但是它只能对一段连续的区间进行查询。于是我们需要想办法让树上需要操作的路径变成一段连续的区间。

    引入一个概念:重儿子,也就是一个节点的儿子中(size)最大的。连接到重儿子的边即为重边

    重儿子组成的,就是重链

    比如在这棵树中,连续的红边组成的就是一条条重链。我们用(top[u])记录节点(u)所在重链的顶端。特别地,没有被重边连接的节点,(top[u]=u),即它们所在重链的顶端就是自身。注意到,当(u)是一条重链的顶端((top[u]=u))时,它的父节点一定在另一条重链上

    始终记住我们的目标:把在树上区间操作转化为在一段连续的区间进行操作。

    考虑如何用(DFS)给树上的每个节点在区间内找到一个合适的位置。我们发现,从根节点出发,优先走重边,这样的(dfs)序似乎有点特殊。

    例如上图,优先走重边的(dfs)序为:(124798356)。很显然,这样的(dfs)序满足同一条重链上的点(dfs)序连续。所以用线段树维护的,就是重链上的信息

    这样操作之后,我们可以做到的是:(O(logn))对一条重链上的信息区间修改,区间查询。

    对于两个节点(u,v),我们可以通过不断地跳重链,直到两个节点在同一条重链上。这个是很好实现的,因为只需要跳到(fa[top[u]]),就到了一条新的重链。

    代码实现仅树剖部分是不麻烦的。我们需要维护的信息有(dep)(节点深度),(fa)(父节点),(son)(重儿子),(sz)(子树节点数,用来判重儿子),这些可以用一次(dfs)完成。

    void dfs1(int u,int f,int d)//fa,dep,son,sz
    {
    	fa[u]=f;
    	dep[u]=d;
    	sz[u]=1;
    	for(int i=head[u];i;i=nxt[i])
    	{
    		int v=to[i];
    		if(v!=f)
    		{
    			dfs1(v,u,d+1);
    			sz[u]+=sz[v];
    			if(sz[v]>sz[son[u]])son[u]=v;
    		}
    
    	}
    }
    

    接下来,就需要把这棵树每个节点压到线段树维护的序列的一个位置了。就像上文说的一样,按照优先重边(dfs)序压入线段树即可。于是记录一个(id[i])表示原树中节点(i)对应的线段树中的下标。(rk[i])反过来记录线段树中下标为(i)的原数编号。

    由于预处理了父节点,所以(dfs2)传参只需要(u)(当前节点)和(t)(当前重链顶端节点)。在遍历儿子之前先(dfs2(son[u],t)),因为(u)(u)的重儿子在同一条重链上。接下来才遍历轻(非重)儿子(v),但是传参为(dfs2(v,v)),因为(v)就是新的一条重链的起点。

    void dfs2(int u,int t)//top,id,rk
    {
    	top[u]=t;
    	id[u]=++tot;
    	rk[tot]=u;
    	if(!son[u])return;
    	dfs2(son[u],t);
    	for(int i=head[u];i;i=nxt[i])
    	{
    		int v=to[i];
    		if(v!=fa[u]&&v!=son[u])
    			dfs2(v,v);
    	}
    }
    

    再回到最开始的问题:

    在一棵树上,将(x)(y)路径上点的点权加上(w),并要求支持查询两个点(x,y)路径间的点权和。

    答案就显得很明了了。

    如果是查询,先保证(dep[x]>dep[y]),然后就和(LCA)类似的,利用重链加速:每次把([top[x],x])这条重链的和累加到答案上,再使(x)跳到另一条重链上,即(x=fa[top[x]]),直到(x,y)在同一条重链上,再把两个点之间的信息统计累加一下即可。

    int getsum(int x,int y)
    {
    	int res=0;
    	while(top[x]!=top[y])
    	{
    		if(dep[top[x]]<dep[top[y]])swap(x,y);
    		sum=0;
    		asksum(1,id[top[x]],id[x]);
    		(res+=sum)%=mod;
    		x=fa[top[x]];
    	}
    	if(id[x]>id[y])swap(x,y);
    	sum=0;
    	asksum(1,id[x],id[y]);
    	(res+=sum)%=mod;
    	return res;
    }
    

    修改同理。

    于是我们发现,虽然我们采用了优先重边的(dfs)序,但它毕竟遍历的都是自己的儿子节点。所以...还可以支持子树操作。因为一棵子树在重边优先的(dfs)序中编号也是连续的。并且这个编号很容易算,因为我们维护了一个(sz)信息。所以树中(x)节点的子树对应的就是线段树维护的([id[x],id[x]+sz[x]-1])这个区间

    于是还是板子一般的线段树区间修改&查询。


    可以注意到线段树部分基本没讲,因为每个人写线段树的方法可能不太一样,蒟蒻我分享的只是树剖的思想。

    另外,为什么树剖每次操作是(O(logn))呢?利用线段树的子树操作自然是(O(logn)),剩下的就是那个像(LCA)一样的跳重链。

    证明:从任意节点向根节点跳重链,经过的重链和轻边(非重边)都是(log)级别的。

    考虑到每走一条轻边,子树大小至少翻倍,否则这就不是条轻边了。于是经过的轻边就最多为(log_2 n)条。而重链和轻边的交替出现的,所以数量也在这个级别。

    于是每次操作就只有(O(logn))的时间复杂度。

    模板题

    以下是代码

    #include<bits/stdc++.h>
    #define int long long
    #define ls (k<<1)
    #define rs (k<<1|1)
    using namespace std;
    const int N=1e5+10;
    struct node
    {
    	int l,r,w,f;
    }t[N<<2];
    int a[N];
    int n,m,r,mod;
    int sum;
    int head[N<<1],to[N<<1],nxt[N<<1],cnt;
    int sz[N],fa[N],dep[N],son[N];
    int top[N],id[N],rk[N],tot;
    inline int read()
    {
       int x=0,f=1;
       char ch=getchar();
       while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
       while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
       return x*f;
    }
    void add(int u,int v)
    {
    	cnt++;
    	to[cnt]=v;
    	nxt[cnt]=head[u];
    	head[u]=cnt;
    }
    void dfs1(int u,int f)
    {
    	fa[u]=f;
    	sz[u]=1;
    	dep[u]=dep[f]+1;
    	for(int i=head[u];i;i=nxt[i])
    	{
    		int v=to[i];
    		if(v==f)continue;
    		dfs1(v,u);
    		sz[u]+=sz[v];
    		if(sz[v]>sz[son[u]])son[u]=v;
    	}
    	return;
    }
    void dfs2(int u,int t)
    {
    	top[u]=t;
    	id[u]=++tot;
    	rk[tot]=u;
    	if(!son[u])return;
    	dfs2(son[u],t);
    	for(int i=head[u];i;i=nxt[i])
    	{
    		int v=to[i];
    		if(v!=fa[u]&&v!=son[u])dfs2(v,v);//新的重链 
    	}
    }
    void build(int k,int l,int r)
    {
    	t[k].l=l,t[k].r=r;
    	if(l==r)
    	{
    		t[k].w=a[rk[l]];
    		return;
    	}
    	int m=l+r>>1;
    	build(ls,l,m);
    	build(rs,m+1,r);
    	t[k].w=t[ls].w+t[rs].w;
    	return;
    }
    void down(int k)
    {
    	t[ls].w+=(t[ls].r-t[ls].l+1)*t[k].f;
    	t[rs].w+=(t[rs].r-t[rs].l+1)*t[k].f;
    	t[ls].f+=t[k].f;
    	t[rs].f+=t[k].f;
    	t[k].f=0;
    }
    void addsum(int k,int x,int y,int p)
    {
    	int l=t[k].l,r=t[k].r;
    	if(x<=l&&r<=y)
    	{
    		t[k].w+=(r-l+1)*p;
    		t[k].f+=p;
    		return;
    	}
    	down(k);
    	int m=l+r>>1;
    	if(x<=m)addsum(ls,x,y,p);
    	if(y>m)addsum(rs,x,y,p);
    	t[k].w=t[ls].w+t[rs].w;
    	return;
    }
    void asksum(int k,int x,int y)
    {
    	int l=t[k].l,r=t[k].r;
    	if(x<=l&&r<=y)
    	{
    		sum+=t[k].w;
    		return;
    	}
    	down(k);
    	int m=l+r>>1;
    	if(x<=m)asksum(ls,x,y);
    	if(y>m)asksum(rs,x,y);
    	t[k].w=t[ls].w+t[rs].w;
    	return;
    }
    //-----------------------------
    int getsum(int x,int y)
    {
    	int res=0;
    	while(top[x]!=top[y])
    	{
    		if(dep[top[x]]<dep[top[y]])swap(x,y);
    		sum=0;
    		asksum(1,id[top[x]],id[x]);
    		(res+=sum)%=mod;
    		x=fa[top[x]];
    	}
    	if(id[x]>id[y])swap(x,y);
    	sum=0;
    	asksum(1,id[x],id[y]);
    	(res+=sum)%=mod;
    	return res;
    }
    void update(int x,int y,int p)
    {
    	while(top[x]!=top[y])
    	{
    		if(dep[top[x]]<dep[top[y]])swap(x,y);
    		addsum(1,id[top[x]],id[x],p);
    		x=fa[top[x]];
    	}
    	if(id[x]>id[y])swap(x,y);
    	addsum(1,id[x],id[y],p);
    	return;
    }
    signed main()
    {
    	n=read(),m=read(),r=read(),mod=read();
    	for(int i=1;i<=n;i++)a[i]=read();
    	for(int i=1;i<n;i++)
    	{
    		int x=read(),y=read();
    		add(x,y),add(y,x);
    	}
    	dfs1(r,0);
    	dfs2(r,r);
    	build(1,1,n);
    	for(int i=1;i<=m;i++)
    	{
    		int x,y,z;
    		int opt=read();
    		if(opt==1)
    		{
    			x=read(),y=read(),z=read();
    			update(x,y,z);
    		}
    		if(opt==2)
    		{
    			x=read(),y=read();
    			printf("%lld
    ",getsum(x,y)%mod);
    		}
    		if(opt==3)
    		{
    			x=read(),z=read();
    			addsum(1,id[x],id[x]+sz[x]-1,z);
    		}
    		if(opt==4)
    		{
    			x=read();
    			sum=0;asksum(1,id[x],id[x]+sz[x]-1);
    			printf("%lld
    ",sum%mod);
    		}
    	}
    	return 0;
    }
    

    代码的确是长,也不算容易调,但是真正妙的是利用轻重链的思想进行的化树为链。

    感谢阅读。

  • 相关阅读:
    appium知识01-环境设置
    移动端测试基础知识02
    魔术方法和反射
    面向对象开发: 封装, 继承, 多态
    正则的用法
    内置方法, 第三方模块(math, random, pickle, json, time, os, shutil, zip, tarfile), 导入包
    推导式(列表, 集合, 字典), 生成器
    迭代器, 高阶函数(map, filter, reduce, sorted) , 递归函数
    函数globals和locals用法, LEGB原则, 闭包函数 , 匿名函数
    字符串, 列表, 元祖, 集合, 字典的相关操作和函数, 深浅copy
  • 原文地址:https://www.cnblogs.com/moyujiang/p/13446439.html
Copyright © 2011-2022 走看看