zoukankan      html  css  js  c++  java
  • 倍增lca学习笔记

    倍增lca学习笔记

    前置

    倍增

    例题P3865 【模板】ST表

    #include <cstdio>
    #include <algorithm>
    #include <cstring>
    #define maxn 500010
    
    using namespace std;
    
    int n, m, st[maxn][21], llog2[maxn];
    
    void init()
    {
    	for(int j = 1; (1 << j) <= n; j++)
    		for(int i = 1; i + (1 << j) - 1 <= n; i++)
    			st[i][j] = max(st[i][j - 1], st[i + (1 << (j - 1))][j - 1]);
    	for(int i = 2; i <= n; i++)
    	{
    		llog2[i] = llog2[i - 1];
    		if((1 << (llog2[i] + 1)) == i) llog2[i]++;
    	}
    }
    
    int query(int l, int r)
    {
    	int k = llog2[r - l + 1];
    	return max(st[l][k], st[r - (1 << k) + 1][k]);
    }
    
    int main() {
    	int l, r;
    	scanf("%d%d", &n, &m);
    	for(int i = 1; i <= n; i++) scanf("%d", &st[i][0]);
    	init();
    	while(m--)
    	{
    		scanf("%d%d", &l, &r);
    		printf("%d
    ", query(l, r));
    	}
    	return 0;
    }
    

    st[i][j]表示从 i 开始往后 $ 2 ^ j$个数里的最大值

    llog[maxn]是为了快速计算log

    init初始化时,用类似于递推的方法预处理好所有(2 ^ j)的长度的最大值

    并且处理好(i)对应的log值(log手动模拟一下就懂)

    查询时数据不一定刚好是(2^j),所以我们为了要让st的端点刚好覆盖查询点

    用log便可以完成这个操作,虽然最后查询时的两段st会有重叠部分,但对答案没有影响

    LCA

    有tarjan和倍增两种写法(我知道的)

    tarjan可以看看同校巨巨写的

    模板P3379 【模板】最近公共祖先(LCA)

    #include <cstdio>
    #include <algorithm>
    #include <cstring>
    #define maxn 600005
    
    using namespace std;
    
    int n, m, cnt, rt, dp[maxn][21], head[maxn], deep[maxn];
    
    struct Edge{
    	int v, next;
    }e[maxn << 1];
    
    void add(int u, int v)
    {
    	e[++cnt].v = v;
    	e[cnt].next = head[u];
    	head[u] = cnt;
    }
    
    void dfs(int x, int fa)
    {
    	dp[x][0] = fa;
    	deep[x] = deep[fa] + 1;
    	for(int i = head[x]; i; i = e[i].next)
    	{
    		int v = e[i].v;
    		if(v == fa) continue;
    		dfs(v, x);
    	}
    }
    
    void init()
    {
    	for(int j = 1; (1 << j) <= n; j++)
    		for(int i = 1; i <= n; i++)
    			if(deep[i] >= (1 << j)) 
    				dp[i][j] = dp[dp[i][j - 1]][j - 1];
    }
    
    int lca(int x, int y)
    {
    	if(deep[x] < deep[y]) swap(x, y);
    	int d = deep[x] - deep[y];
    	for(int j = 20; j >= 0; j--) if(d & (1 << j)) x = dp[x][j];
    	if(x == y) return x;
    	for(int j = 20; j >= 0; j--)
    	{
    		if(dp[x][j] != dp[y][j])
    		{
    			x = dp[x][j];
    			y = dp[y][j];
    		}
    	}
    	return dp[x][0];
    }
    
    int main() {
    	int u, v;
    	scanf("%d%d%d", &n, &m, &rt);
    	for(int i = 1; i < n; i++)
    	{
    		scanf("%d%d", &u, &v);
    		add(u, v);
    		add(v, u);
    	}
    	dfs(rt, 0);
    	init();
    	for(int i = 1; i <= m; i++)
    	{
    		scanf("%d%d", &u, &v);
    		printf("%d
    ", lca(u, v));
    	}
    	return 0;
    }
    

    dp[i][j](i)(2^j)级祖先

    链式前向星建边不说了


    void dfs(int x, int fa)
    {
    	dp[x][0] = fa;
    	deep[x] = deep[fa] + 1;
    	for(int i = head[x]; i; i = e[i].next)
    	{
    		int v = e[i].v;
    		if(v == fa) continue;
    		dfs(v, x);
    	}
    }
    

    dfs预处理深度

    你的深度时你爸爸的深度 + 1


    void init()
    {
    	for(int j = 1; (1 << j) <= n; j++)
    		for(int i = 1; i <= n; i++)
    			if(deep[i] >= (1 << j)) 
    				dp[i][j] = dp[dp[i][j - 1]][j - 1];
    }
    

    倍增思想

    (i)的每(2^j)级祖先都预处理出来

    (2^j) = (2^{j-1}) + (2^{j-1})


    int lca(int x, int y)
    {
    	if(deep[x] < deep[y]) swap(x, y);
    	int d = deep[x] - deep[y];
    	for(int j = 20; j >= 0; j--) if(d & (1 << j)) x = dp[x][j];
    	if(x == y) return x;
    	for(int j = 20; j >= 0; j--)
    	{
    		if(dp[x][j] != dp[y][j])
    		{
    			x = dp[x][j];
    			y = dp[y][j];
    		}
    	}
    	return dp[x][0];
    }
    

    保持x更深,所以要交换

    不断将x向上提(j从大到小防止跳过头)

    如果x == y则说明两个点在一条线上,直接输出x

    否则继续同时将x和y向上提

    当他们的(2^j)级祖先相同了跳出循环

    输出dp[x][0]


    例题

    倍增例题

    P1613 跑路

    #include <cstdio>
    #include <algorithm>
    #include <cstring>
    #include <queue>
    #define maxn 10010
    #define maxm 100
    #define INF 1e9 + 7
    
    using namespace std;
    
    int head[maxm], cnt, dis[maxm], vis[maxm], s, n, m, ans, f[maxm][maxm][maxm], val[maxm][maxm];
    queue<pair<int, int> > q;
    
    struct node{
    	int v, next;
    }e[maxn];
    
    void add(int u, int v)
    {
    	e[++cnt].v = v;
    	e[cnt].next = head[u];
    	head[u] = cnt;
    }
    
    void dijkstra()
    {
    	for(int i = 1; i <= n; i++)	dis[i] = INF;
    	dis[s] = 0;
    	q.push(make_pair(0, s));
    	while(!q.empty())
    	{
    		int x = q.front().second;
    		q.pop();
    		if(vis[x]) continue;
    		vis[x] = 1;
    		for(int i = head[x]; i; i = e[i].next)
    		{
    			int v = e[i].v;
    			if(!vis[v] && dis[v] > dis[x] + val[x][v])
    			{
    				dis[v] = dis[x] + val[x][v];
    				q.push(make_pair(-dis[v], v));
    			}
    		}
    	}
    }
    
    int main() {
    	scanf("%d%d", &n, &m);
    	for(int i = 1; i <= m; i++)
    	{
    		int u, v;
    		scanf("%d%d", &u, &v);
    		f[u][v][0] = 1;
    	}
    	for(int k = 1; k <= 64; k++)
    		for(int j = 1; j <= n; j++)
    			for(int u = 1; u <= n; u++)
    				for(int v = 1; v <= n; v++)
    					if(f[u][j][k - 1] && f[j][v][k - 1]) f[u][v][k] = 1;
    	for(int u = 1; u <= n; u++)
    		for(int v = 1; v <= n; v++)
    			for(int k = 0; k <= 64; k++)
    				if(f[u][v][k])
    				{
    					val[u][v] = 1;
    					add(u, v);
    					break;
    				}
    	s = 1;	
    	dijkstra();
    	printf("%d", dis[n]);
    	return 0;
    } 
    

    这道题一看看上去好像直接跑最短路

    但显然路程最短并不是时间最短,所以要预处理

    这个跑路器1s可以跑(2^k)千米

    也就是说我们应该先将所有相距的(2^k)千米的点连起来,建成一个新图(用倍增)

    再新图上跑最短路,最后输出即可


    lca例题

    P1967 货车运输

    #include <cstdio>
    #include <cstring>
    #include <algorithm>
    #include <queue>
    #include <stack>
    #define maxn 10010
    #define INF 1e9 + 7
    
    using namespace std;
    
    int n, m, head[maxn], cnt, vis[maxn], deep[maxn], fa[maxn][21], w[maxn][21];
    int dad[maxn];
    
    struct node{
    	int v, next, val;
    }e[100005];
    
    struct mst{
    	int a, b, c;
    }arr[50005];
    
    void add(int u, int v, int val)
    {
    	e[++cnt].v = v;
    	e[cnt].val = val;
    	e[cnt].next = head[u];
    	head[u] = cnt;
    }
    
    void dfs(int x)
    {
    	vis[x] = 1;
    	for(int i = head[x]; i; i =e[i].next)
    	{
    		int v = e[i].v;
    		if(vis[v]) continue;
    		deep[v] = deep[x] + 1;
    		fa[v][0] = x;
    		w[v][0] = e[i].val;
    		dfs(v);
    	}
    	return ;
    }
    
    void init(){
    	for(int j = 1; j <= 20; j++)
    		for(int i = 1; i <= n; i++)
    			if(deep[i] >= (1 << j))
    			{
    				fa[i][j] = fa[fa[i][j - 1]][j - 1];
    				w[i][j] = min(w[i][j - 1], w[fa[i][j - 1]][j - 1]);
    			}
    }
    
    int find(int x)
    {
    	if(dad[x] != x) return dad[x] = find(dad[x]);
    	return x;
    }
    
    int cmp(const mst s1, const mst s2)
    {
    	return s1.c > s2.c; 
    }
    
    int lca(int x, int y)
    {
    	if(find(x) != find(y)) return -1;
    	int ans = INF;
    	if(deep[x] > deep[y]) swap(x, y);
    	for(int j = 20; j >= 0; j--)
    	{
    		if(deep[fa[y][j]] >= deep[x])
    		{
    			ans = min(ans, w[y][j]);
    			y = fa[y][j];
    		}
    	}
    	if(x == y) return ans;
    	for(int i = 20; i >= 0; i--)
    		if(fa[x][i] != fa[y][i])
    		{
    			ans = min(ans, min(w[x][i], w[y][i]));
    			x = fa[x][i];
    			y = fa[y][i];
    		}
    	ans = min(ans, min(w[x][0], w[y][0]));
    	return ans;
    }
    
    int main() {
    	scanf("%d%d", &n, &m);
    	for(int i = 1; i <= n; i++) dad[i] = i;
    	for(int i = 1; i <= m; i++)
    		scanf("%d%d%d", &arr[i].a, &arr[i].b, &arr[i].c);
    	sort(arr + 1, arr + 1 + m, cmp);
    	for(int i = 1; i <= m; i++)
    	{
    		int r1 = find(arr[i].a);
    		int r2 = find(arr[i].b);
    		if(r1 != r2)
    		{
    			dad[r2] = r1;
    			add(arr[i].a, arr[i].b, arr[i].c);
    			add(arr[i].b, arr[i].a, arr[i].c);
    		}
    	}
    	for(int i = 1; i <= n; i++)
    		if(!vis[i])
    		{
    			deep[i] = 1;
    			dfs(i);
    			fa[i][0] = i;
    			w[i][0] = INF;
    		}
    	init();
    	int q, x, y;
    	scanf("%d", &q);
    	for(int i = 1; i <= q; i++)
    	{
    		scanf("%d%d", &x, &y);
    		printf("%d
    ", lca(x, y));
    	}
    	return 0;
    }
    

    题目中说明可能有多条道路连接两个城市

    所以我们先用最大生成树使图中任意两个城市只有一条路并且权值最大

    接下来我们树上求lca即可


    本章节无关题(可用lca但最后没用)

    P1351 联合权值

    #include <cstdio>
    #include <algorithm>
    #include <cstring>
    #define maxn 200020
    typedef int int_;
    #define int long long
    const int mod = 10007;
    
    using namespace std;
    
    int n, head[maxn], val[maxn], sum, maxx,cnt;
    
    struct node{
    	int v, next;
    }e[maxn << 1];
    
    void add(int u, int v)
    {
    	e[++cnt].v = v;
    	e[cnt].next = head[u];
    	head[u] = cnt;
    }
    
    int_ main() {
    	scanf("%lld", &n);
    	int u, v;
    	for(int i = 1; i < n; i++)
    	{
    		scanf("%lld%lld", &u, &v);
    		add(u, v);
    		add(v, u);
    	}
    	for(int i = 1; i <= n; i++)
    		scanf("%lld", &val[i]);
    	int ttot, tmax, j;
    	for(int i = 1; i <= n; i++)
    	{
    		j = head[i];
    		tmax = val[e[j].v];
    		ttot = val[e[j].v] % mod;
    		j = e[j].next;
    		for(; j; j = e[j].next)
    		{
    			sum = (sum + ttot * val[e[j].v]) % mod;
    			maxx = max(maxx, tmax * val[e[j].v]);
    			ttot = (ttot + val[e[j].v]) % mod;
    			tmax = max(tmax, val[e[j].v]); 
    		}
    	}
    	printf("%lld %lld", maxx, (sum * 2) % mod);
    	return 0; 
    }
    

    两点之间距离为2则其中间必有一个中转点

    枚举每个中转点,用乘法结合律维护最大值与和

    最后因为两点间有序,所以和要乘2

  • 相关阅读:
    其它 Surface 重装系统 win10
    电商 商品数据分析 市场洞察 导出数据后 横线对比 python实现2
    电商 商品数据分析 市场洞察 导出数据后 横线对比 python实现
    电商 商品数据分析 市场洞察 导出数据后 横线对比
    Python excel转换为json
    关于四舍五入
    MBProgressHUD 显示后,为何不能点击屏幕其他地方
    使用容器挂载NFS
    luogu P1128 [HNOI2001]求正整数 dp 高精度
    EC R 86 D Multiple Testcases 构造 贪心 二分
  • 原文地址:https://www.cnblogs.com/wyswyz/p/11280128.html
Copyright © 2011-2022 走看看