zoukankan      html  css  js  c++  java
  • 树链剖分详解

    树链剖分是个很简单的算法

    树链剖分一共分为两种,一种是重链剖分,比较常见;还有一种是长链剖分,比较少见

    一.重链剖分

    以下讲解都以Luogu P3384 【模板】树链剖分为例

    重儿子:对于每一个非叶子节点,它的儿子中 以那个儿子为根的子树节点数最大的儿子 为该节点的重儿子 (Ps: 感谢@shzr大佬指出我此句话的表达不严谨qwq, 已修改)

    轻儿子:对于每一个非叶子节点,它的儿子中 非重儿子 的剩下所有儿子即为轻儿子

    叶子节点没有重儿子也没有轻儿子(因为它没有儿子。。)

    重边:一个父亲连接他的重儿子的边称为重边 //原写法:连接任意两个重儿子的边叫做重边

    轻边:剩下的即为轻边

    重链:相邻重边连起来的 连接一条重儿子 的链叫重链

    对于叶子节点,若其为轻儿子,则有一条以自己为起点的长度为1的链

    每一条重链以轻儿子为起点

    1256986-20171203120143991-1630008815.png

    这图好像是洛咕上的,我还是懒得自己画

    说起来这些概念实际很简单

    但写起来还是要有较强码力的

    我们先要写把轻重链求出的函数

    一共需要写两个函数

    1.dfs1

    dfs1主要求出:

    1.该节点的子树大小(1+所有子节点子树大小之和)

    2.重儿子(找到所有子节点中子树大小最大的)

    3.父节点

    4.深度

    dfs1还是比较简单的qaq

    inline void dfs1(register int x)
    {
    	size[x]=1;
    	for(register int i=head[x];i;i=e[i].next)
    		if(e[i].to!=fa[x])
    		{
    			dep[e[i].to]=dep[x]+1;
    			fa[e[i].to]=x;
    			dfs1(e[i].to);
    			size[x]+=size[e[i].to];
    			if(size[e[i].to]>size[son[x]])
    				son[x]=e[i].to;
    		}
    }
    

    dfs2

    dfs2是重链剖分的重点

    dfs2要求出:

    1.树的dfs序(优先搜重儿子)

    2.在树的dfs序之下,珂以把树上的值存到连续的数列中,到时就珂以线段树维护

    3.每个重链的顶端,方便到时候跳链(不懂的话后面会讲)

    inline void dfs2(register int x,register int t)
    {
    	dl[x]=++tot;
    	a[tot]=ch[x];
    	top[x]=t;
    	if(son[x])
    		dfs2(son[x],t);
    	for(register int i=head[x];i;i=e[i].next)
    		if(e[i].to!=fa[x]&&e[i].to!=son[x])
    			dfs2(e[i].to,e[i].to);
    }
    

    跑完两个dfs之后就珂以用线段树

    build建树:

    inline void pushup(register int x)
    {
    	sum[x]=sum[x<<1]+sum[x<<1|1];
    	sum[x]%=mod;
    }
    inline void build(register int x,register int l,register int r)
    {
    	if(l==r)
    	{
    		sum[x]=a[l];
    		tag[x]=0;
    		return;
    	}
    	int mid=l+r>>1;
    	build(x<<1,l,mid);
    	build(x<<1|1,mid+1,r);
    	pushup(x);
    }
    

    下面是处理查询

    操作1:把x节点到y节点路径上的值加z

    这里需要一个跳链的函数——cal1

    inline void pushdown(register int x,register int l,register int r)
    {
    	int ls=x<<1,rs=x<<1|1,mid=l+r>>1;
    	sum[ls]+=(mid-l+1)*tag[x];
    	sum[rs]+=(r-mid)*tag[x];
    	tag[ls]+=tag[x];
    	tag[rs]+=tag[x];
    	sum[ls]%=mod;
    	sum[rs]%=mod;
    	tag[ls]%=mod;
    	tag[rs]%=mod;
    	tag[x]=0;
    }
    inline void update(register int x,register int l,register int r,register int L,register int R,register int k)
    {
    	if(L<=l&&r<=R)
    	{
    		sum[x]+=(r-l+1)*k;
    		tag[x]+=k;
    		sum[x]%=mod;
    		tag[x]%=mod;
    		return;
    	}
    	if(tag[x])
    		pushdown(x,l,r);
    	int mid=l+r>>1;
    	if(L<=mid)
    		update(x<<1,l,mid,L,R,k);
    	if(R>=mid+1)
    		update(x<<1|1,mid+1,r,L,R,k);
    	pushup(x);
    }
    inline void cal1(register int x,register int y,register int z)
    {
    	int fx=top[x],fy=top[y];
    	while(fx!=fy)
    	{
    		if(dep[fx]<dep[fy])
    		{
    			swap(x,y);
    			swap(fx,fy);
    		}
    		update(1,1,tot,dl[fx],dl[x],z);
    		x=fa[fx];
    		fx=top[x];
    	}
    	if(dl[x]>dl[y])
    		swap(x,y);
    	update(1,1,tot,dl[x],dl[y],z);
    }
    

    操作2:查询x到y路径点权之和

    和操作1差不多,需要跳链

    inline ll query(register int x,register int l,register int r,register int L,register int R)
    {
    	if(L<=l&&r<=R)
    		return sum[x];
    	if(tag[x])
    		pushdown(x,l,r);
    	ll res=0;
    	int mid=l+r>>1;
    	if(L<=mid)
    		res+=query(x<<1,l,mid,L,R)%mod;
    	if(R>=mid+1)
    		res+=query(x<<1|1,mid+1,r,L,R)%mod;
    	return res%mod;
    }
    inline ll cal2(register int x,register int y)
    {
    	ll res=0;
    	int fx=top[x],fy=top[y];
    	while(fx!=fy)
    	{
    		if(dep[fx]<dep[fy])
    		{
    			swap(x,y);
    			swap(fx,fy);
    		}
    		res=(res%mod+query(1,1,tot,dl[fx],dl[x])%mod)%mod;
    		x=fa[fx];
    		fx=top[x];
    	}
    	if(dl[x]>dl[y])
    		swap(x,y);
    	res=(res%mod+query(1,1,tot,dl[x],dl[y])%mod)%mod;
    	return res%mod;
    }
    

    操作3:把x的子树内所有节点全值加z

    考虑到子树内dfs序是相连的

    所以被修改区间是一个连续的区间,所以直接上线段树

    update(1,1,tot,dl[x],dl[x]+size[x]-1,z%mod);
    

    操作四:求x的子树内所有节点的和

    和操作3一样,珂以直接用线段树

    write(query(1,1,tot,dl[x],dl[x]+size[x]-1)%mod);
    

    最后上一下重链剖分整体代码

    #include <bits/stdc++.h>
    #define ll long long
    #define N 100005
    using namespace std;
    inline ll read()
    {
    	register ll x=0,f=1;register char ch=getchar();
    	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    	while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
    	return x*f;
    }
    inline void write(register ll x)
    {
    	if(!x)putchar('0');if(x<0)x=-x,putchar('-');
    	static int sta[36];int tot=0;
    	while(x)sta[tot++]=x%10,x/=10;
    	while(tot)putchar(sta[--tot]+48);
    }
    struct node{
    	int to,next;
    }e[N<<1];
    int head[N],cnt=0;
    inline void add(register int u,register int v)
    {
    	e[++cnt]=(node){v,head[u]};
    	head[u]=cnt;
    }
    ll ch[N];
    ll n,m,rt,mod;
    ll size[N],dep[N],fa[N],son[N];
    ll tot=0,dl[N],a[N],top[N];
    inline void dfs1(register int x)
    {
    	size[x]=1;
    	for(register int i=head[x];i;i=e[i].next)
    		if(e[i].to!=fa[x])
    		{
    			dep[e[i].to]=dep[x]+1;
    			fa[e[i].to]=x;
    			dfs1(e[i].to);
    			size[x]+=size[e[i].to];
    			if(size[e[i].to]>size[son[x]])
    				son[x]=e[i].to;
    		}
    }
    inline void dfs2(register int x,register int t)
    {
    	dl[x]=++tot;
    	a[tot]=ch[x];
    	top[x]=t;
    	if(son[x])
    		dfs2(son[x],t);
    	for(register int i=head[x];i;i=e[i].next)
    		if(e[i].to!=fa[x]&&e[i].to!=son[x])
    			dfs2(e[i].to,e[i].to);
    }
    ll sum[N<<3],tag[N<<3];
    inline void pushup(register int x)
    {
    	sum[x]=sum[x<<1]+sum[x<<1|1];
    	sum[x]%=mod;
    }
    inline void build(register int x,register int l,register int r)
    {
    	if(l==r)
    	{
    		sum[x]=a[l];
    		tag[x]=0;
    		return;
    	}
    	int mid=l+r>>1;
    	build(x<<1,l,mid);
    	build(x<<1|1,mid+1,r);
    	pushup(x);
    }
    inline void pushdown(register int x,register int l,register int r)
    {
    	int ls=x<<1,rs=x<<1|1,mid=l+r>>1;
    	sum[ls]+=(mid-l+1)*tag[x];
    	sum[rs]+=(r-mid)*tag[x];
    	tag[ls]+=tag[x];
    	tag[rs]+=tag[x];
    	sum[ls]%=mod;
    	sum[rs]%=mod;
    	tag[ls]%=mod;
    	tag[rs]%=mod;
    	tag[x]=0;
    }
    inline void update(register int x,register int l,register int r,register int L,register int R,register int k)
    {
    	if(L<=l&&r<=R)
    	{
    		sum[x]+=(r-l+1)*k;
    		tag[x]+=k;
    		sum[x]%=mod;
    		tag[x]%=mod;
    		return;
    	}
    	if(tag[x])
    		pushdown(x,l,r);
    	int mid=l+r>>1;
    	if(L<=mid)
    		update(x<<1,l,mid,L,R,k);
    	if(R>=mid+1)
    		update(x<<1|1,mid+1,r,L,R,k);
    	pushup(x);
    }
    inline ll query(register int x,register int l,register int r,register int L,register int R)
    {
    	if(L<=l&&r<=R)
    		return sum[x];
    	if(tag[x])
    		pushdown(x,l,r);
    	ll res=0;
    	int mid=l+r>>1;
    	if(L<=mid)
    		res+=query(x<<1,l,mid,L,R)%mod;
    	if(R>=mid+1)
    		res+=query(x<<1|1,mid+1,r,L,R)%mod;
    	return res%mod;
    }
    inline void cal1(register int x,register int y,register int z)
    {
    	int fx=top[x],fy=top[y];
    	while(fx!=fy)
    	{
    		if(dep[fx]<dep[fy])
    		{
    			swap(x,y);
    			swap(fx,fy);
    		}
    		update(1,1,tot,dl[fx],dl[x],z);
    		x=fa[fx];
    		fx=top[x];
    	}
    	if(dl[x]>dl[y])
    		swap(x,y);
    	update(1,1,tot,dl[x],dl[y],z);
    }
    inline ll cal2(register int x,register int y)
    {
    	ll res=0;
    	int fx=top[x],fy=top[y];
    	while(fx!=fy)
    	{
    		if(dep[fx]<dep[fy])
    		{
    			swap(x,y);
    			swap(fx,fy);
    		}
    		res=(res%mod+query(1,1,tot,dl[fx],dl[x])%mod)%mod;
    		x=fa[fx];
    		fx=top[x];
    	}
    	if(dl[x]>dl[y])
    		swap(x,y);
    	res=(res%mod+query(1,1,tot,dl[x],dl[y])%mod)%mod;
    	return res%mod;
    }
    int main()
    {
    	n=read(),m=read(),rt=read(),mod=read();	
    	for(register int i=1;i<=n;++i)
    		ch[i]=read(),ch[i]%=mod;
    	for(register int i=1;i<n;++i)
    	{
    		int u=read(),v=read();
    		add(u,v),add(v,u);
    	}
    	dep[rt]=1;
    	fa[rt]=rt;
    	dfs1(rt);
    	dfs2(rt,rt);
    	build(1,1,n);
    	while(m--)
    	{
    		int opt=read();
    		if(opt==1)
    		{
    			int x=read(),y=read(),z=read();
    			cal1(x,y,z%mod);
    		}
    		else if(opt==2)
    		{
    			int x=read(),y=read();
    			write(cal2(x,y)%mod);
    			printf("
    ");
    		}
    		else if(opt==3)
    		{
    			int x=read(),z=read();
    			update(1,1,tot,dl[x],dl[x]+size[x]-1,z%mod);
    		}
    		else
    		{
    			int x=read();
    			write(query(1,1,tot,dl[x],dl[x]+size[x]-1)%mod);
    			printf("
    ");
    		}
    	}
    	return 0;
    }
    

    相关题目

    1.Luogu P2146 [NOI2015]软件包管理器

    树剖练手好题

    2.Luogu CF343D Water Tree

    树剖后用珂朵莉树

    3.Luogu CF375D Tree and Queries

    树剖后莫队暴力求解(也可以称之为树的dfs序)

    4.Luogu P4069 [SDOI2016]游戏

    树剖+李超线段树

    长链剖分

    咕咕咕

  • 相关阅读:
    从1到n中找到任意num个数的和为sum的所有组合
    算法导论5.12
    使用c++技术实现下载网页
    算法导论5.13
    感慨
    算法导论2.37
    [转载]Yahoo!的分布式数据平台PNUTS简介及感悟
    Bigtable 论文笔记
    GFS 论文笔记
    MapReduce论文笔记
  • 原文地址:https://www.cnblogs.com/yzhang-rp-inf/p/9966136.html
Copyright © 2011-2022 走看看