zoukankan      html  css  js  c++  java
  • 倍增lca学习笔记

    倍增lca学习笔记

    前置

    倍增

    例题P3865 【模板】ST表

    #include <cstdio>
    #include <algorithm>
    #include <cstring>
    #define maxn 500010
    
    using namespace std;
    
    int n, m, st[maxn][21], llog2[maxn];
    
    void init()
    {
    	for(int j = 1; (1 << j) <= n; j++)
    		for(int i = 1; i + (1 << j) - 1 <= n; i++)
    			st[i][j] = max(st[i][j - 1], st[i + (1 << (j - 1))][j - 1]);
    	for(int i = 2; i <= n; i++)
    	{
    		llog2[i] = llog2[i - 1];
    		if((1 << (llog2[i] + 1)) == i) llog2[i]++;
    	}
    }
    
    int query(int l, int r)
    {
    	int k = llog2[r - l + 1];
    	return max(st[l][k], st[r - (1 << k) + 1][k]);
    }
    
    int main() {
    	int l, r;
    	scanf("%d%d", &n, &m);
    	for(int i = 1; i <= n; i++) scanf("%d", &st[i][0]);
    	init();
    	while(m--)
    	{
    		scanf("%d%d", &l, &r);
    		printf("%d
    ", query(l, r));
    	}
    	return 0;
    }
    

    st[i][j]表示从 i 开始往后 $ 2 ^ j$个数里的最大值

    llog[maxn]是为了快速计算log

    init初始化时,用类似于递推的方法预处理好所有(2 ^ j)的长度的最大值

    并且处理好(i)对应的log值(log手动模拟一下就懂)

    查询时数据不一定刚好是(2^j),所以我们为了要让st的端点刚好覆盖查询点

    用log便可以完成这个操作,虽然最后查询时的两段st会有重叠部分,但对答案没有影响

    LCA

    有tarjan和倍增两种写法(我知道的)

    tarjan可以看看同校巨巨写的

    模板P3379 【模板】最近公共祖先(LCA)

    #include <cstdio>
    #include <algorithm>
    #include <cstring>
    #define maxn 600005
    
    using namespace std;
    
    int n, m, cnt, rt, dp[maxn][21], head[maxn], deep[maxn];
    
    struct Edge{
    	int v, next;
    }e[maxn << 1];
    
    void add(int u, int v)
    {
    	e[++cnt].v = v;
    	e[cnt].next = head[u];
    	head[u] = cnt;
    }
    
    void dfs(int x, int fa)
    {
    	dp[x][0] = fa;
    	deep[x] = deep[fa] + 1;
    	for(int i = head[x]; i; i = e[i].next)
    	{
    		int v = e[i].v;
    		if(v == fa) continue;
    		dfs(v, x);
    	}
    }
    
    void init()
    {
    	for(int j = 1; (1 << j) <= n; j++)
    		for(int i = 1; i <= n; i++)
    			if(deep[i] >= (1 << j)) 
    				dp[i][j] = dp[dp[i][j - 1]][j - 1];
    }
    
    int lca(int x, int y)
    {
    	if(deep[x] < deep[y]) swap(x, y);
    	int d = deep[x] - deep[y];
    	for(int j = 20; j >= 0; j--) if(d & (1 << j)) x = dp[x][j];
    	if(x == y) return x;
    	for(int j = 20; j >= 0; j--)
    	{
    		if(dp[x][j] != dp[y][j])
    		{
    			x = dp[x][j];
    			y = dp[y][j];
    		}
    	}
    	return dp[x][0];
    }
    
    int main() {
    	int u, v;
    	scanf("%d%d%d", &n, &m, &rt);
    	for(int i = 1; i < n; i++)
    	{
    		scanf("%d%d", &u, &v);
    		add(u, v);
    		add(v, u);
    	}
    	dfs(rt, 0);
    	init();
    	for(int i = 1; i <= m; i++)
    	{
    		scanf("%d%d", &u, &v);
    		printf("%d
    ", lca(u, v));
    	}
    	return 0;
    }
    

    dp[i][j](i)(2^j)级祖先

    链式前向星建边不说了


    void dfs(int x, int fa)
    {
    	dp[x][0] = fa;
    	deep[x] = deep[fa] + 1;
    	for(int i = head[x]; i; i = e[i].next)
    	{
    		int v = e[i].v;
    		if(v == fa) continue;
    		dfs(v, x);
    	}
    }
    

    dfs预处理深度

    你的深度时你爸爸的深度 + 1


    void init()
    {
    	for(int j = 1; (1 << j) <= n; j++)
    		for(int i = 1; i <= n; i++)
    			if(deep[i] >= (1 << j)) 
    				dp[i][j] = dp[dp[i][j - 1]][j - 1];
    }
    

    倍增思想

    (i)的每(2^j)级祖先都预处理出来

    (2^j) = (2^{j-1}) + (2^{j-1})


    int lca(int x, int y)
    {
    	if(deep[x] < deep[y]) swap(x, y);
    	int d = deep[x] - deep[y];
    	for(int j = 20; j >= 0; j--) if(d & (1 << j)) x = dp[x][j];
    	if(x == y) return x;
    	for(int j = 20; j >= 0; j--)
    	{
    		if(dp[x][j] != dp[y][j])
    		{
    			x = dp[x][j];
    			y = dp[y][j];
    		}
    	}
    	return dp[x][0];
    }
    

    保持x更深,所以要交换

    不断将x向上提(j从大到小防止跳过头)

    如果x == y则说明两个点在一条线上,直接输出x

    否则继续同时将x和y向上提

    当他们的(2^j)级祖先相同了跳出循环

    输出dp[x][0]


    例题

    倍增例题

    P1613 跑路

    #include <cstdio>
    #include <algorithm>
    #include <cstring>
    #include <queue>
    #define maxn 10010
    #define maxm 100
    #define INF 1e9 + 7
    
    using namespace std;
    
    int head[maxm], cnt, dis[maxm], vis[maxm], s, n, m, ans, f[maxm][maxm][maxm], val[maxm][maxm];
    queue<pair<int, int> > q;
    
    struct node{
    	int v, next;
    }e[maxn];
    
    void add(int u, int v)
    {
    	e[++cnt].v = v;
    	e[cnt].next = head[u];
    	head[u] = cnt;
    }
    
    void dijkstra()
    {
    	for(int i = 1; i <= n; i++)	dis[i] = INF;
    	dis[s] = 0;
    	q.push(make_pair(0, s));
    	while(!q.empty())
    	{
    		int x = q.front().second;
    		q.pop();
    		if(vis[x]) continue;
    		vis[x] = 1;
    		for(int i = head[x]; i; i = e[i].next)
    		{
    			int v = e[i].v;
    			if(!vis[v] && dis[v] > dis[x] + val[x][v])
    			{
    				dis[v] = dis[x] + val[x][v];
    				q.push(make_pair(-dis[v], v));
    			}
    		}
    	}
    }
    
    int main() {
    	scanf("%d%d", &n, &m);
    	for(int i = 1; i <= m; i++)
    	{
    		int u, v;
    		scanf("%d%d", &u, &v);
    		f[u][v][0] = 1;
    	}
    	for(int k = 1; k <= 64; k++)
    		for(int j = 1; j <= n; j++)
    			for(int u = 1; u <= n; u++)
    				for(int v = 1; v <= n; v++)
    					if(f[u][j][k - 1] && f[j][v][k - 1]) f[u][v][k] = 1;
    	for(int u = 1; u <= n; u++)
    		for(int v = 1; v <= n; v++)
    			for(int k = 0; k <= 64; k++)
    				if(f[u][v][k])
    				{
    					val[u][v] = 1;
    					add(u, v);
    					break;
    				}
    	s = 1;	
    	dijkstra();
    	printf("%d", dis[n]);
    	return 0;
    } 
    

    这道题一看看上去好像直接跑最短路

    但显然路程最短并不是时间最短,所以要预处理

    这个跑路器1s可以跑(2^k)千米

    也就是说我们应该先将所有相距的(2^k)千米的点连起来,建成一个新图(用倍增)

    再新图上跑最短路,最后输出即可


    lca例题

    P1967 货车运输

    #include <cstdio>
    #include <cstring>
    #include <algorithm>
    #include <queue>
    #include <stack>
    #define maxn 10010
    #define INF 1e9 + 7
    
    using namespace std;
    
    int n, m, head[maxn], cnt, vis[maxn], deep[maxn], fa[maxn][21], w[maxn][21];
    int dad[maxn];
    
    struct node{
    	int v, next, val;
    }e[100005];
    
    struct mst{
    	int a, b, c;
    }arr[50005];
    
    void add(int u, int v, int val)
    {
    	e[++cnt].v = v;
    	e[cnt].val = val;
    	e[cnt].next = head[u];
    	head[u] = cnt;
    }
    
    void dfs(int x)
    {
    	vis[x] = 1;
    	for(int i = head[x]; i; i =e[i].next)
    	{
    		int v = e[i].v;
    		if(vis[v]) continue;
    		deep[v] = deep[x] + 1;
    		fa[v][0] = x;
    		w[v][0] = e[i].val;
    		dfs(v);
    	}
    	return ;
    }
    
    void init(){
    	for(int j = 1; j <= 20; j++)
    		for(int i = 1; i <= n; i++)
    			if(deep[i] >= (1 << j))
    			{
    				fa[i][j] = fa[fa[i][j - 1]][j - 1];
    				w[i][j] = min(w[i][j - 1], w[fa[i][j - 1]][j - 1]);
    			}
    }
    
    int find(int x)
    {
    	if(dad[x] != x) return dad[x] = find(dad[x]);
    	return x;
    }
    
    int cmp(const mst s1, const mst s2)
    {
    	return s1.c > s2.c; 
    }
    
    int lca(int x, int y)
    {
    	if(find(x) != find(y)) return -1;
    	int ans = INF;
    	if(deep[x] > deep[y]) swap(x, y);
    	for(int j = 20; j >= 0; j--)
    	{
    		if(deep[fa[y][j]] >= deep[x])
    		{
    			ans = min(ans, w[y][j]);
    			y = fa[y][j];
    		}
    	}
    	if(x == y) return ans;
    	for(int i = 20; i >= 0; i--)
    		if(fa[x][i] != fa[y][i])
    		{
    			ans = min(ans, min(w[x][i], w[y][i]));
    			x = fa[x][i];
    			y = fa[y][i];
    		}
    	ans = min(ans, min(w[x][0], w[y][0]));
    	return ans;
    }
    
    int main() {
    	scanf("%d%d", &n, &m);
    	for(int i = 1; i <= n; i++) dad[i] = i;
    	for(int i = 1; i <= m; i++)
    		scanf("%d%d%d", &arr[i].a, &arr[i].b, &arr[i].c);
    	sort(arr + 1, arr + 1 + m, cmp);
    	for(int i = 1; i <= m; i++)
    	{
    		int r1 = find(arr[i].a);
    		int r2 = find(arr[i].b);
    		if(r1 != r2)
    		{
    			dad[r2] = r1;
    			add(arr[i].a, arr[i].b, arr[i].c);
    			add(arr[i].b, arr[i].a, arr[i].c);
    		}
    	}
    	for(int i = 1; i <= n; i++)
    		if(!vis[i])
    		{
    			deep[i] = 1;
    			dfs(i);
    			fa[i][0] = i;
    			w[i][0] = INF;
    		}
    	init();
    	int q, x, y;
    	scanf("%d", &q);
    	for(int i = 1; i <= q; i++)
    	{
    		scanf("%d%d", &x, &y);
    		printf("%d
    ", lca(x, y));
    	}
    	return 0;
    }
    

    题目中说明可能有多条道路连接两个城市

    所以我们先用最大生成树使图中任意两个城市只有一条路并且权值最大

    接下来我们树上求lca即可


    本章节无关题(可用lca但最后没用)

    P1351 联合权值

    #include <cstdio>
    #include <algorithm>
    #include <cstring>
    #define maxn 200020
    typedef int int_;
    #define int long long
    const int mod = 10007;
    
    using namespace std;
    
    int n, head[maxn], val[maxn], sum, maxx,cnt;
    
    struct node{
    	int v, next;
    }e[maxn << 1];
    
    void add(int u, int v)
    {
    	e[++cnt].v = v;
    	e[cnt].next = head[u];
    	head[u] = cnt;
    }
    
    int_ main() {
    	scanf("%lld", &n);
    	int u, v;
    	for(int i = 1; i < n; i++)
    	{
    		scanf("%lld%lld", &u, &v);
    		add(u, v);
    		add(v, u);
    	}
    	for(int i = 1; i <= n; i++)
    		scanf("%lld", &val[i]);
    	int ttot, tmax, j;
    	for(int i = 1; i <= n; i++)
    	{
    		j = head[i];
    		tmax = val[e[j].v];
    		ttot = val[e[j].v] % mod;
    		j = e[j].next;
    		for(; j; j = e[j].next)
    		{
    			sum = (sum + ttot * val[e[j].v]) % mod;
    			maxx = max(maxx, tmax * val[e[j].v]);
    			ttot = (ttot + val[e[j].v]) % mod;
    			tmax = max(tmax, val[e[j].v]); 
    		}
    	}
    	printf("%lld %lld", maxx, (sum * 2) % mod);
    	return 0; 
    }
    

    两点之间距离为2则其中间必有一个中转点

    枚举每个中转点,用乘法结合律维护最大值与和

    最后因为两点间有序,所以和要乘2

  • 相关阅读:
    开源项目之Android StandOut(浮动窗口)
    小智慧7
    安卓学习
    asp.net学习Request和Response的其他属性
    bash中的转义
    POJ 1833 排列
    Django点滴(四)ORM对象存取
    POJ 1681 Painter's Problem
    linux2.6.32在mini2440开发板上移植(21)之WebServer服务器移植
    [gkk传智]static与多态及向下向上转型,及多态调用总结
  • 原文地址:https://www.cnblogs.com/wyswyz/p/11280128.html
Copyright © 2011-2022 走看看