zoukankan      html  css  js  c++  java
  • 树链剖分学习笔记

    简介

    树链剖分,顾名思义,就是把树剖分成链,在链上进行一系列的操作。

    下面我们就来学习一下这个算法。

    概念

    树链剖分引入了很多新的概念:

    1. 重儿子:一个节点所有的儿子中子树(size)最大的儿子。
    2. 轻儿子:一个节点的儿子中除了重儿子都是轻儿子。
    3. 重边:一个节点与它的重儿子所组成的边。
    4. 轻边:一个节点与它的轻儿子组成的边。
    5. 重链:若干条重边组成的链。
    6. 轻链:若干条轻边组成的链。

    思想

    树链剖分经常与线段树相结合进行链上的操作。

    因此线段树是必须要掌握的。

    树链剖分一开始要进行(2)(dfs)

    第一次(dfs)需要记录出一个节点的父亲、节点的深度和节点的重儿子。

    第二次(dfs)需要对每个节点进行重新标号,按照重儿子优先的顺序遍历;还要记录出节点所在链的顶端;以及当前标号的点的编号。

    然后就是线段树的基本操作。

    对链进行维护时需要将两端点往上跳,直到它们在同一条剖分好的链上。

    代码

    这里以ZJOI2008 树的统计为例题讲解一下树链剖分的代码。

    #include <iostream>
    #include <cstdio>
    #include <cstdlib>
    #include <cstring>
    #include <algorithm>
    #include <cmath>
    #include <cctype>
    #include <string>
    #define itn int
    #define gI gi
    
    using namespace std;
    
    inline int gi()
    {
    	int f = 1, x = 0; char c = getchar();
    	while (c < '0' || c > '9') { if (c == '-') f = -1; c = getchar();}
    	while (c >= '0' && c <= '9') { x = x * 10 + c - '0'; c = getchar();}
    	return f * x;
    }
    
    int q, n, m;
    int tot, head[100003], nxt[100003], ver[100003];
    int dfn[100003], dep[100003], fa[100003];
    int top[100003], son[100003], sz[100003];
    int pre[100003], tim;
    int a[100003];
    
    inline void add(int u, int v)//邻接表存图
    {
    	ver[++tot] = v, nxt[tot] = head[u], head[u] = tot;
    }
    
    void dfs1(itn u, int f)//第一次dfs
    {
    	fa[u] = f/*记录父亲*/, sz[u] = 1/*记录子树大小*/, dep[u] = dep[f] + 1;/*标记深度*/ 
    	int maxsize = -1;//最大子树大小
    	for (itn i = head[u]; i; i = nxt[i])//遍历子节点
    	{
    		int v = ver[i];
    		if (v == f) continue;
    		dfs1(v, u);
    		sz[u] = sz[u] + sz[v];//计算子树大小
    		if (sz[v] > maxsize)//当前子树大小超过当前最大的子树大小
    		{
    			maxsize = sz[v], son[u] = v;//更新最大子树大小并标记重儿子
    		} 
    	}
    }
    
    void dfs2(int u, int f)
    {
    	dfn[u] = ++tim/*将树重新标号*/, top[u] = f/*记录链顶*/, pre[tim] = u/*重新编号后编号为tim的节点编号*/;
    	if (son[u]) dfs2(son[u], f);//优先遍历重儿子
    	for (itn i = head[u]; i; i = nxt[i])
    	{
    		int v = ver[i];
    		if (v == son[u] || v == fa[u]) continue;//处理过就不需要再处理了
    		dfs2(v, v);//找出下一条链
    	}
    }
    
    /******以下为线段树******/
    
    int sum[400003], maxs[400003];
    
    inline int ls(int u) {return u << 1;}//左儿子
    inline int rs(int u) {return (u << 1) | 1;}//右儿子
    
    inline void pushup(int p)//上传标记
    {
    	sum[p] = sum[ls(p)] + sum[rs(p)];//区间和
    	maxs[p] = max(maxs[ls(p)], maxs[rs(p)]);//区间最大值
    }
    
    void build(int l, int r, itn p)//建树
    {
    	if (l == r) {sum[p] = maxs[p] = a[pre[l]];/*注意是pre[l]*/ return;}//子节点
    	int mid = (l + r) >> 1;
    	build(l, mid, ls(p)); build(mid + 1, r, rs(p));
    	pushup(p);//上传节点
    }
    
    void update(int x, int y, itn l, int r, int p)//更新节点信息
    {
    	if (l == r) {sum[p] = maxs[p] = y; return;}//找到了要更新的节点
    	int mid = (l + r) >> 1;
    	if (x <= mid) update(x, y, l, mid, ls(p));//左区间寻找
    	else update(x, y, mid + 1, r, rs(p));//右区间寻找
    	pushup(p);//上传节点
    }
    
    itn getmax(int ql, int qr, int l, itn r, int p)//区间最大值查找
    {
    	if (ql <= l && r <= qr) return maxs[p];//当前区间包含于要寻找的区间
    	itn mid = (l + r) >> 1, ans = -1000000000;
    	if (ql <= mid) ans = max(ans, getmax(ql, qr, l, mid, ls(p)));//向左寻找最大值
    	if (qr > mid) ans = max(ans, getmax(ql, qr, mid + 1, r, rs(p)));//向右寻找最大值
    	pushup(p);//上传节点
    	return ans;//返回答案
    }
    
    itn getans(int ql, int qr, int l, itn r, int p)//区间和查找,与区间最大值查找没有什么区别
    {
    	if (ql <= l && r <= qr) return sum[p];
    	itn mid = (l + r) >> 1, ans = 0;
    	if (ql <= mid) ans = ans + getans(ql, qr, l, mid, ls(p));
    	if (qr > mid) ans = ans + getans(ql, qr, mid + 1, r, rs(p));
    	pushup(p);
    	return ans;
    }
    
    /******以上为线段树******/
    
    inline int qmax(int l, itn r)//查找路径上最大值
    {
    	itn ans = -1000000000;
    	while (top[l] != top[r])//不在同一条链上
    	{
    		if (dep[top[l]] < dep[top[r]]) swap(l, r);//找链顶深度大的节点
    		ans = max(ans, getmax(dfn[top[l]], dfn[l], 1, n, 1));//更新最大值
    		l = fa[top[l]];//跳到当前链顶的父亲
    	}
    	if (dep[l] > dep[r]) swap(l, r);//要满足左端点深度小
    	ans = max(ans, getmax(dfn[l], dfn[r], 1, n, 1));//更新答案
    	return ans;//返回
    }
    
    inline int qsum(int l, itn r)//求路径权值和,与查找最大值同理
    {
    	itn ans = 0;
    	while (top[l] != top[r])
    	{
    		if (dep[top[l]] < dep[top[r]]) swap(l, r);
    		ans = ans + getans(dfn[top[l]], dfn[l], 1, n, 1);
    		l = fa[top[l]];
    	}
    	if (dep[l] < dep[r]) swap(l, r);
    	ans = ans + getans(dfn[r], dfn[l], 1, n, 1);
    	return ans;
    }
    
    int main()
    {
    	n = gi();
    	for (int i = 1; i < n; i+=1) 
    	{
    		int u = gI(), v = gI(); 
    		add(u, v), add(v, u);
    	}
    	for (int i = 1; i <= n; i+=1) a[i] = gi();
    	dfs1(1, -1); dfs2(1, 1); build(1, n, 1);//预处理
    	q = gi();
    	while (q--)
    	{
    		char s[10];
    		scanf("%s", s);
    		int u = gi(), v = gi();
    		if (s[1] == 'M') printf("%d
    ", qmax(u, v));//区间最大值查找
    		else if (s[1] == 'S') printf("%d
    ", qsum(u, v));//求区间和
    		else update(dfn[u], v, 1, n, 1);//更新节点
    	}
    	return 0;
    }
    

    应用

    树链剖分求( exttt{LCA})

    代码如下(以洛谷模板为例):

    #include <iostream>
    #include <cstdio>
    #include <cstdlib>
    #include <cstring>
    #include <algorithm>
    #include <cmath>
    #include <cctype>
    #define itn int
    #define gI gi
    
    using namespace std;
    
    inline int gi()
    {
    	int f = 1, x = 0; char c = getchar();
    	while (c < '0' || c > '9') { if (c == '-') f = -1; c = getchar();}
    	while (c >= '0' && c <= '9') { x = x * 10 + c - '0'; c = getchar();}
    	return f * x;
    }
    
    int n, m, rt, dfn[500003], dep[500003], fa[500003], sz[500003], son[500003], pre[500003], top[500003];
    int tot, head[2000003], nxt[2000003], ver[2000003];
    
    inline void add(itn u, int v)
    {
    	ver[++tot] = v, nxt[tot] = head[u], head[u] = tot;
    }
    
    void dfs1(int u, int f)
    {
    	fa[u] = f, dep[u] = dep[f] + 1, sz[u] = 1;
    	for (itn i = head[u]; i; i = nxt[i])
    	{
    		int v = ver[i];
    		if (v == f) continue;
    		dfs1(v, u);
    		sz[u] = sz[u] + sz[v];
    		if (sz[v] > sz[son[u]]) son[u] = v;
    	}
    }
    
    int tim;
    
    void dfs2(itn u, int f)
    {
    	top[u] = f, dfn[u] = ++tim, pre[tim] = u;
    	if (!son[u]) return;
    	dfs2(son[u], f);
    	for (int i = head[u]; i; i = nxt[i])
    	{
    		int v = ver[i];
    		if (v == fa[u] || v == son[u]) continue;
    		dfs2(v, v);
    	}
    }
    
    int main()
    {
    	n = gi(), m = gi(), rt = gi();
    	for (itn i = 1; i < n; i+=1)
    	{
    		int u = gi(), v = gi();
    		add(u, v), add(v, u);
    	}
    	dfs1(rt, rt); 
    	dfs2(rt, rt);
    	while (m--)
    	{
    		int u = gi(), v = gi();
    		while (top[u] != top[v])
    		{
    			if (dep[top[u]] < dep[top[v]]) swap(u, v);
    			u = fa[top[u]];
    		}
    		if (dep[u] < dep[v]) printf("%lld
    ", u);
    		else printf("%lld
    ", v);
    	}
    	return 0;
    }
    

    总结

    理解一个算法的思想很重要。

    代码要熟练地打出来才算真正理解。

    记录一下我踩过的坑:

    • 建树时把(pre[l])写成了(l)

    • 跳端点时没有注意左端点编号小于右端点编号;

    • 子树(size)初始化成(0)

    • ( exttt{LCA})时把<写成>

    就这样吧~

  • 相关阅读:
    1008: 约瑟夫问题
    1009: 恺撒Caesar密码
    1006: 日历问题
    1007: 生理周期
    Asp.Net Core 发布和部署( MacOS + Linux + Nginx )
    ASP.NET Core Docker部署
    Asp.Net Core 发布和部署(Linux + Jexus )
    ASP.NET Core 十种方式扩展你的 Views
    基于机器学习的web异常检测
    Disruptor深入解读
  • 原文地址:https://www.cnblogs.com/xsl19/p/shupou.html
Copyright © 2011-2022 走看看