zoukankan      html  css  js  c++  java
  • LCA 最近公共祖先 (笔记、模板)

    求lca的方法大体有三种:
    1.dfs+RMQ(线段树 ST表什么的) 在线
    2.倍增 在线
    3.tarjan 离线

    ps:离线:所有查询全输入后一次解决
    在线:有一个查询输出一次
    以下模板题为 洛谷 P3379 【模板】最近公共祖先(LCA)

    1.首先dfs求出

    1>dfs遍历时经过的所有节点的位置
    2>每个节点第一次出现的位置
    3>每个节点的深度
    查询时先取出两个节点的位置求出这两个位置间的深度最小的节点
    这个节点就是lca
    code:

    //By Menteur_Hxy 2068ms
    #include<cstdio>
    #include<iostream>
    #include<cstring>
    using namespace std;
    
    const int INF=0x3f3f3f3f;
    const int MAX=500110;
    
    int n,qu,root,cnt;
    int head[MAX],ver[MAX*2],first[MAX],log[MAX*2],deep[MAX],f[MAX*2][21];
    
    struct edges{
        int to,next;
    }edge[MAX*2+5];
    
    void add(int x,int y) {
        edge[++cnt].next=head[x];
        edge[cnt].to=y;
        head[x]=cnt;
    }
    
    void dfs(int u,int pre) {
        ver[++cnt]=u;
        first[u]=cnt;
        for(int i=head[u];i;i=edge[i].next) {
            int v=edge[i].to;
            if(v!=pre) {
                deep[v]=deep[u]+1;
                dfs(v,u);
                ver[++cnt]=u;
            }
        }
    }
    
    void ST() {
        for(int i=1;i<=cnt;i++) f[i][0]=ver[i];
        log[1]=0,log[2]=1;
        for(int i=3;i<=cnt;i++) {
            log[i]=log[i-1];
            if(1<<log[i-1]+1==i) log[i]++;
        }
        int k=log[cnt];
        for(int j=1;j<=k;j++)
            for(int i=1;i+(1<<j)-1<=cnt;i++) {//一定注意位运算要加括号!!!
                if(deep[f[i][j-1]]<deep[f[i+(1<<j-1)][j-1]])
                    f[i][j]=f[i][j-1];
                else f[i][j]=f[i+(1<<j-1)][j-1];
            }
    
    }
    
    int RMQ(int l,int r) {
        int k=log[r-l+1];
    //  cout<<l<<":"<<deep[f[l][k]]<<" "<<r<<":"<<deep[f[r-(1<<k)+1][k]]<<endl;
        if(deep[f[l][k]]<deep[f[r-(1<<k)+1][k]])
            return f[l][k];
        else return f[r-(1<<k)+1][k];
    //  return min(f[l][k],f[r-1<<k+1][k]);
    }
    
    int ask(int x,int y) {
        x=first[x];y=first[y];
        if(x>y) swap(x,y);
        return RMQ(x,y);
    }
    
    int main() {
        scanf("%d %d %d",&n,&qu,&root); 
        for(int i=1;i<n;i++) {
            int a,b;
            scanf("%d %d",&a,&b);
            add(a,b);
            add(b,a);
        }
        cnt=0;
        dfs(root,-1);
    //  for(int i=1;i<=cnt;i++) cout<<ver[i]<<" ";
    //  cout<<endl; 
        ST();
        for(int i=1;i<=qu;i++) {
            int a,b;
            scanf("%d %d",&a,&b);
            printf("%d
    ",ask(a,b));
        }
        return 0;
    }

    2.倍增

    同样需要dfs,不过只需求出深度和每个节点的父亲
    现在设f[i][j] 表示i的第2^j个祖先 (eg:f[i][0] 即为i的父亲)
    所以显然有 f[i][j]=f[f[i][j-1]][j-1] (i的第2^j-1个祖先的第2^j-1个祖先是i的2^j个祖先)
    查询两节点时,先将深度较大的节点向上移动直到与另一个节点深度相同,判断此时两节点是否相同如果不同,就从大到小枚举尝试往上跳直到两节点父亲相同此时父亲就是lca

    code:

    //By Menteur_Hxy 1860ms
    #include<cstdio>
    #include<iostream>
    #include<cstdlib>
    #include<algorithm>
    #include<cstring>
    #include<cmath>
    #include<ctime>
    using namespace std;
    
    const int MAX=500010;
    
    int n,qu,root,cnt;
    int head[MAX],deep[MAX],f[MAX][20];
    /*
    f[i][j] i的第2^j个祖先
    */
    struct edges{
        int to,next;
    }edge[MAX*2];
    
    void add(int x,int y) {
        edge[++cnt].to=y;
        edge[cnt].next=head[x];
        head[x]=cnt;
    }
    
    void dfs_bz(int cur) {
        for(int i=head[cur];i;i=edge[i].next) {
            int v=edge[i].to;
            if(v!=f[cur][0]) { //判断!!! 
    //          cout<<cur<<":"<<v<<endl;
                f[v][0]=cur;
                deep[v]=deep[cur]+1;
                dfs_bz(v);
            }
        }
    }
    
    void init() {
        for(int j=1;(1<<j)<=n;j++)
            for(int i=1;i<=n;i++) 
                if(f[i][j-1]!=-1)
                    f[i][j]=f[f[i][j-1]][j-1];
    }
    
    int lca_bz(int x,int y) {
        if(deep[x]<deep[y]) swap(x,y);
        int d=deep[x]-deep[y];
        for(int i=0;d;i++,d>>=1) 
            if(d&1) x=f[x][i];
    
    //  for(int i=0;(1<<i)<=d;i++) 
    //      if((1<<i)&d) x=f[x][i];  这样写也行 
    
    //  cout<<x<<","<<y<<endl;
    //  cout<<deep[x]<<"-"<<deep[y]<<endl;
        if(x!=y) {
            for(int i=17;i>=0;i--) { //=0也算!!!
                if(f[x][i]!=f[y][i])
                    x=f[x][i],y=f[y][i];
            }
            return f[x][0];
        }
        else return x;
    }
    
    int main(){
        scanf("%d %d %d",&n,&qu,&root);
        for(int i=1;i<n;i++) {
            int a,b;
            scanf("%d %d",&a,&b);
            add(a,b);
            add(b,a);
        }
        f[root][0]=-1;
        dfs_bz(root);
        init();
        for(int i=1;i<=qu;i++) {
            int x,y;
            scanf("%d %d",&x,&y);
            printf("%d
    ",lca_bz(x,y));
        }
        return 0;
    }

    3.tarjan

    先记录查询的问题
    进行dfs,每次在节点(u)返回前查询与它构成问题的所有节点(v),如果其中有之前已经遍历过的节点(vis[v]=true),则这两个节点的lca为find(v),所有v都查过后,将它与它的父节点的集合合并,之后返回。

    code;

    //By Menteur_Hxy 912ms
    #include<cstdio>
    #include<iostream>
    #include<cstring>
    using namespace std;
    
    const int MAX=500010;
    
    int n,qu,root,cnt;
    int head[MAX],f[MAX],head_[MAX],vis[MAX],ans[MAX];
    
    struct edges{
        int to,next;
    }edge[MAX*2],edge_[MAX*2];
    
    void add(int x,int y) {
        edge[++cnt].to=y;
        edge[cnt].next=head[x];
        head[x]=cnt;
    }
    
    void add_(int x,int y) {
        edge_[++cnt].to=y;
        edge_[cnt].next=head_[x];
        head_[x]=cnt;
    }
    
    int find(int x) {
        return f[x]==x?x:f[x]=find(f[x]);
    }
    
    void tarjan(int u,int pre) {
        vis[u]=1;
        for(int i=head[u];i;i=edge[i].next) {
            int v=edge[i].to;
            if(v!=pre) {
                tarjan(v,u);
                int a=find(v),b=find(u);
                f[a]=b;
            }   
        }
        for(int i=head_[u];i;i=edge_[i].next) {
            int v=edge_[i].to;
            if(vis[v]) 
                ans[(i+1)/2]=find(v);
        }
    }
    
    int main() {
        scanf("%d %d %d",&n,&qu,&root);
        for(int i=1;i<n;i++) {
            f[i]=i;
            int a,b;
            scanf("%d %d",&a,&b);
            add(a,b);
            add(b,a);
        }
        f[n]=n;
        cnt=0;
        for(int i=1;i<=qu;i++) {
            int a,b;
            scanf("%d %d",&a,&b);
            add_(a,b);
            add_(b,a);
        }
        tarjan(root,-1);
        for(int i=1;i<=qu;i++) 
            printf("%d
    ",ans[i]);
        return 0;
    }
    版权声明:本文为博主原创文章,未经博主允许不得转载。 博主:https://www.cnblogs.com/Menteur-Hxy/
  • 相关阅读:
    设计模式01之 简单工厂模式(创建模式)
    UML系列05之 基本流程图
    UML系列04之 UML时序图
    UML系列03之 UML类图(二)
    UML系列02之 UML类图(一)
    LaTex in Markdown
    Ubuntu18.04 下的Gif录制工具
    Python3 与 C# 扩展之~基础衍生
    Python3 与 C# 扩展之~模块专栏
    Python3 与 C# 面向对象之~异常相关
  • 原文地址:https://www.cnblogs.com/Menteur-Hxy/p/9248014.html
Copyright © 2011-2022 走看看