zoukankan      html  css  js  c++  java
  • 图论--最近公共祖先问题(LCA)模板

    最近公共祖先问题(LCA)是求一颗树上的某两点距离他们最近的公共祖先节点,由于树的特性,树上两点之间路径是唯一的,所以对于很多处理关于树的路径问题的时候为了得知树两点的间的路径,LCA是几乎最有效的解法。

    首先是LCA的倍增算法。算法主体是依靠首先对整个树的预处理DFS,用来预处理出每个点的直接父节点,同时可以处理出每个点的深度和与根节点的距离,然后利用类似RMQ的思想处理出每个点的 2 的幂次的祖先节点,这就可以用 nlogn 的时间完成整个预处理的工作。然后每一次求两个点的LCA时只要对两个点深度经行考察,将深度深的那个利用倍增先爬到和浅的同一深度,然后一起一步一步爬直到爬到相同节点,就是LCA了。

    具体模板是从鹏神的模板小改来的。

    注释方便理解版:

     1 #include<stdio.h>
     2 #include<string.h>
     3 #include<algorithm>
     4 using namespace std;
     5 
     6 const int maxn=1e5+5;
     7 const int maxm=1e5+5;
     8 const int maxl=20;        //总点数的log范围,一般会开稍大一点
     9 
    10 int fa[maxl][maxn],dep[maxn],dis[maxn];        //fa[i][j]是j点向上(不包括自己)2**i 层的父节点,dep是某个点的深度(根节点深度为0),dis是节点到根节点的距离
    11 int head[maxn],point[maxm],nxt[maxm],val[maxm],size;
    12 int n;
    13 
    14 void init(){
    15     size=0;
    16     memset(head,-1,sizeof(head));
    17 }
    18 
    19 void add(int a,int b,int v){
    20     point[size]=b;
    21     val[size]=v;
    22     nxt[size]=head[a];
    23     head[a]=size++;
    24     point[size]=a;
    25     val[size]=v;
    26     nxt[size]=head[b];
    27     head[b]=size++;
    28 }
    29 
    30 void Dfs(int s,int pre,int d){        //传入当前节点标号,父亲节点标号,以及当前深度
    31     fa[0][s]=pre;                    //当前节点的上一层父节点是传入的父节点标号
    32     dep[s]=d;
    33     for(int i=head[s];~i;i=nxt[i]){
    34         int j=point[i];
    35         if(j==pre)continue;
    36         dis[j]=dis[s]+val[i];
    37         Dfs(j,s,d+1);
    38     }
    39 }
    40 
    41 void Pre(){
    42     dis[1]=0;
    43     Dfs(1,-1,0);
    44     for(int k=0;k+1<maxl;++k){        //类似RMQ的做法,处理出点向上2的幂次的祖先。
    45         for(int v=1;v<=n;++v){
    46             if(fa[k][v]<0)fa[k+1][v]=-1;
    47             else fa[k+1][v]=fa[k][fa[k][v]];    //处理出两倍距离的祖先
    48         }
    49     }
    50 }
    51 
    52 int Lca(int u,int v){
    53     if(dep[u]>dep[v])swap(u,v);        //定u为靠近根的点
    54     for(int k=maxl-1;k>=0;--k){
    55         if((dep[v]-dep[u])&(1<<k))    //根据层数差值的二进制向上找v的父亲
    56             v=fa[k][v];
    57     }
    58     if(u==v)return u;                //u为v的根
    59     for(int k=maxl-1;k>=0;--k){
    60         if(fa[k][u]!=fa[k][v]){        //保持在相等层数,同时上爬寻找相同父节点
    61             u=fa[k][u];
    62             v=fa[k][v];
    63         }
    64     }
    65     return fa[0][u];                //u离lca只差一步
    66 }

    木有注释版:

     1 #include<stdio.h>
     2 #include<string.h>
     3 #include<algorithm>
     4 using namespace std;
     5 
     6 const int maxn=1e5+5;
     7 const int maxm=1e5+5;
     8 const int maxl=20;
     9 
    10 int fa[maxl][maxn],dep[maxn],dis[maxn];
    11 int head[maxn],point[maxm],nxt[maxm],val[maxm],size;
    12 int n;
    13 
    14 void init(){
    15     size=0;
    16     memset(head,-1,sizeof(head));
    17 }
    18 
    19 void add(int a,int b,int v){
    20     point[size]=b;
    21     val[size]=v;
    22     nxt[size]=head[a];
    23     head[a]=size++;
    24     point[size]=a;
    25     val[size]=v;
    26     nxt[size]=head[b];
    27     head[b]=size++;
    28 }
    29 
    30 void Dfs(int s,int pre,int d){
    31     fa[0][s]=pre;
    32     dep[s]=d;
    33     for(int i=head[s];~i;i=nxt[i]){
    34         int j=point[i];
    35         if(j==pre)continue;
    36         dis[j]=dis[s]+val[i];
    37         Dfs(j,s,d+1);
    38     }
    39 }
    40 
    41 void Pre(){
    42     dis[1]=0;
    43     Dfs(1,-1,0);
    44     for(int k=0;k+1<maxl;++k){
    45         for(int v=1;v<=n;++v){
    46             if(fa[k][v]<0)fa[k+1][v]=-1;
    47             else fa[k+1][v]=fa[k][fa[k][v]];
    48         }
    49     }
    50 }
    51 
    52 int Lca(int u,int v){
    53     if(dep[u]>dep[v])swap(u,v);
    54     for(int k=maxl-1;k>=0;--k){
    55         if((dep[v]-dep[u])&(1<<k))
    56             v=fa[k][v];
    57     }
    58     if(u==v)return u;
    59     for(int k=maxl-1;k>=0;--k){
    60         if(fa[k][u]!=fa[k][v]){
    61             u=fa[k][u];
    62             v=fa[k][v];
    63         }
    64     }
    65     return fa[0][u];
    66 }

    静态树上路径求最小值:LCA倍增

     1 #include<bits/stdc++.h>
     2 using namespace std;
     3 
     4 const int maxn=1e6+5;
     5 const int maxm=2e6+5;
     6 const int maxl=22;
     7 const int INF = 0x3f3f3f3f;
     8 
     9 int fa[maxl][maxn],dep[maxn],dis[maxl][maxn];
    10 int head[maxn],point[maxm],nxt[maxm],val[maxm],size;
    11 int vis[maxn];
    12 int n,q,tmp=INF;
    13 
    14 void init(){
    15     size=0;
    16     memset(head,-1,sizeof(head));
    17     memset(vis,0,sizeof(vis));
    18 }
    19 
    20 void add(int a,int b){
    21     point[size]=b;
    22     nxt[size]=head[a];
    23     head[a]=size++;
    24     point[size]=a;
    25     nxt[size]=head[b];
    26     head[b]=size++;
    27 }
    28 
    29 void Dfs(int s,int pre,int d){
    30     fa[0][s]=pre;
    31     dis[0][s]=s;
    32     dep[s]=d;
    33     for(int i=head[s];~i;i=nxt[i]){
    34         int j=point[i];
    35         if(j==pre)continue;
    36         Dfs(j,s,d+1);
    37     }
    38 }
    39 
    40 void Pre(){
    41     Dfs(1,-1,0);
    42     for(int k=0;k+1<maxl;++k){
    43         for(int v=1;v<=n;++v){
    44             if(fa[k][v]<0)fa[k+1][v]=-1;
    45             else fa[k+1][v]=fa[k][fa[k][v]];
    46             if(fa[k][v]<0)dis[k+1][v]=dis[k][v];
    47             else dis[k+1][v]=min(dis[k][v],dis[k][fa[k][v]]);
    48         }
    49     }
    50 }
    51 
    52 int Lca(int u,int v){
    53     tmp = min( u, v );
    54     if(dep[u]>dep[v])swap(u,v);
    55     for(int k=maxl-1;k>=0;--k){
    56         if((dep[v]-dep[u])&(1<<k)){
    57             tmp = min( tmp, dis[k][v]);
    58             v=fa[k][v];
    59         }
    60     }
    61     tmp = min( tmp,v );
    62     if(u==v)return u;
    63     for(int k=maxl-1;k>=0;--k){
    64         if(fa[k][u]!=fa[k][v]){
    65             tmp=min(tmp,min(dis[k][u],dis[k][v]));
    66             u=fa[k][u],v=fa[k][v];
    67         }
    68     }
    69     tmp = min( tmp, min(u,v));
    70     tmp = min( tmp, fa[0][u]);
    71     return fa[0][u];
    72 }    
    73 //tmp即为u、v路径上的最小值

     

    离线Tarjan的做法主要是防止由于每个点对可能被询问多次,导致每次求都需要 logn 的时间,会超时,所以离线来一并处理所有的询问。

    Tarjan的做法是通过递归到最底层,然后开始不断递归回去合并并查集,这样就能够在访问完每个点之后赋值它有关切另一个点已经被访问过的询问。

    同样是鹏神的模板修改成自己的代码风格后的。

    注释版:

     1 #include<stdio.h>        //差不多要这些头文件
     2 #include<string.h>
     3 #include<vector>
     4 #include<algorithm>
     5 using namespace std;
     6 
     7 const int maxn=1e5+5;    //点数、边数、询问数
     8 const int maxm=2e5+5;
     9 const int maxq=1e4+5;
    10 
    11 int n;
    12 int head[maxn],nxt[maxm],point[maxm],val[maxm],size;
    13 int vis[maxn],fa[maxn],dep[maxn],dis[maxn];
    14 int ans[maxq];
    15 vector<pair<int,int> >v[maxn];    //记录询问、问题编号
    16 
    17 void init(){
    18     memset(head,-1,sizeof(head));
    19     size=0;
    20     memset(vis,0,sizeof(vis));
    21     for(int i=1;i<=n;++i){
    22         v[i].clear();
    23         fa[i]=i;
    24     }
    25     dis[1]=dep[1]=0;
    26 }
    27 
    28 void add(int a,int b,int v){
    29     point[size]=b;
    30     val[size]=v;
    31     nxt[size]=head[a];
    32     head[a]=size++;
    33     point[size]=a;
    34     val[size]=v;
    35     nxt[size]=head[b];
    36     head[b]=size++;
    37 }
    38 
    39 int find(int x){
    40     return x==fa[x]?x:fa[x]=find(fa[x]);
    41 }
    42 
    43 void Tarjan(int s,int pre){
    44     for(int i=head[s];~i;i=nxt[i]){
    45         int j=point[i];
    46         if(j!=pre){
    47             dis[j]=dis[s]+val[i];
    48             dep[j]=dep[s]+1;
    49             Tarjan(j,s);                //这里Tarjan的DPS操作必须在并查集合并之前,这样才能保证求lca的时候lca是每一小部分合并时的祖先节点,如果顺序交换,那么所有的查询都会得到 1 节点,就是错误的
    50             int x=find(j),y=find(s);
    51             if(x!=y)fa[x]=y;
    52         }
    53     }
    54     vis[s]=1;
    55     for(int i=0;i<v[s].size();++i){
    56         int j=v[s][i].first;
    57         if(vis[j]){
    58             int lca=find(j);
    59             int id=v[s][i].second;
    60             ans[id]=lca;            //这里视题目要求给答案赋值
    61         //    ans[id]=dep[s]+dep[j]-2*dep[lca];
    62         //    ans[id]=dis[s]+dis[j]-2*dis[lca];
    63         }
    64     }
    65 }
    66 
    67 
    68 
    69 for(int i=1;i<=k;++i){        //主函数中的主要部分
    70     int a,b;
    71     scanf("%d%d",&a,&b);
    72     v[a].push_back(make_pair(b,i));        //加问题的时候两个点都要加一次
    73     v[b].push_back(make_pair(a,i));
    74 }
    75 Tarjan(1,0);

    木有注释版:

     1 #include<stdio.h>
     2 #include<string.h>
     3 #include<vector>
     4 #include<algorithm>
     5 using namespace std;
     6 
     7 const int maxn=1e5+5;
     8 const int maxm=2e5+5;
     9 const int maxq=1e4+5;
    10 
    11 int n;
    12 int head[maxn],nxt[maxm],point[maxm],val[maxm],size;
    13 int vis[maxn],fa[maxn],dep[maxn],dis[maxn];
    14 int ans[maxq];
    15 vector<pair<int,int> >v[maxn];
    16 
    17 void init(){
    18     memset(head,-1,sizeof(head));
    19     size=0;
    20     memset(vis,0,sizeof(vis));
    21     for(int i=1;i<=n;++i){
    22         v[i].clear();
    23         fa[i]=i;
    24     }
    25     dis[1]=dep[1]=0;
    26 }
    27 
    28 void add(int a,int b,int v){
    29     point[size]=b;
    30     val[size]=v;
    31     nxt[size]=head[a];
    32     head[a]=size++;
    33     point[size]=a;
    34     val[size]=v;
    35     nxt[size]=head[b];
    36     head[b]=size++;
    37 }
    38 
    39 int find(int x){
    40     return x==fa[x]?x:fa[x]=find(fa[x]);
    41 }
    42 
    43 void Tarjan(int s,int pre){
    44     for(int i=head[s];~i;i=nxt[i]){
    45         int j=point[i];
    46         if(j!=pre){
    47             dis[j]=dis[s]+val[i];
    48             dep[j]=dep[s]+1;
    49             Tarjan(j,s);
    50             int x=find(j),y=find(s);
    51             if(x!=y)fa[x]=y;
    52         }
    53     }
    54     vis[s]=1;
    55     for(int i=0;i<v[s].size();++i){
    56         int j=v[s][i].first;
    57         if(vis[j]){
    58             int lca=find(j);
    59             int id=v[s][i].second;
    60             ans[id]=lca;
    61         //    ans[id]=dep[s]+dep[j]-2*dep[lca];
    62         //    ans[id]=dis[s]+dis[j]-2*dis[lca];
    63         }
    64     }
    65 }
    66 
    67 
    68 
    69 for(int i=1;i<=k;++i){
    70     int a,b;
    71     scanf("%d%d",&a,&b);
    72     v[a].push_back(make_pair(b,i));
    73     v[b].push_back(make_pair(a,i));
    74 }
    75 Tarjan(1,0);

    另外,现在又有LCA用dfs序+RMQ的做法,可以实现O(nlogn)预处理,O(1)查询的LCA,基本可以完全替代倍增LCA和TarjanLCA,但是树上路径长度和树上路径最小值无法用这个来做。

      1 #include <bits/stdc++.h>
      2 using namespace std;
      3 
      4 const int maxn = 2e5+5;
      5 const int maxl = 20;
      6 int vis[maxn],dep[maxn],dp[maxn][maxl];
      7 
      8 int head[maxn],in[maxn],id[maxn];
      9 int point[maxn],nxt[maxn],sz;
     10 int val[maxn];
     11 int fa[maxl][maxn];        //fa[i][j]是j点向上(不包括自己)2**i 层的父节点,dep是某个点的深度(根节点深度为0),dis是节点到根节点的距离
     12 int n;
     13 
     14 void init(){
     15     sz = 0;
     16     memset(head,-1,sizeof(head));
     17     memset(fa,-1,sizeof(fa));
     18 }
     19 
     20 void Pre(){
     21     for(int k=0;k+1<maxl;++k){        //类似RMQ的做法,处理出点向上2的幂次的祖先。
     22         for(int v=1;v<=n;++v){
     23             if(fa[k][v]<0)fa[k+1][v]=-1;
     24             else fa[k+1][v]=fa[k][fa[k][v]];    //处理出两倍距离的祖先
     25         }
     26     }
     27 }
     28 
     29 void dfs(int u,int p,int d,int&k){
     30     fa[0][u]=p;                    //当前节点的上一层父节点是传入的父节点标号
     31     vis[k] = u;
     32     id[u] = k;
     33     dep[k++]=d;
     34     for(int i = head[u];~i;i=nxt[i]){
     35         int v = point[i];
     36         if(v == p)continue;
     37         dfs(v,u,d+1,k);
     38         vis[k] = u;
     39         dep[k++]=d;
     40     }
     41 }
     42 
     43 void RMQ(int root){
     44     int k =0 ;
     45     dfs(root,-1,0,k);
     46     int m = k;
     47     int e= (int)(log2(m+1.0));
     48     for(int i = 0 ; i < m ; ++ i)dp[i][0]=i;
     49     for(int j = 1 ; j <= e ; ++ j){
     50         for(int i = 0 ; i + ( 1<< j ) - 1 < m ; ++ i){
     51             int N = i + (1<<(j-1));
     52             if(dep[dp[i][j-1]] < dep[dp[N][j-1]]){
     53                 dp[i][j] = dp[i][j-1];
     54             }
     55             else dp[i][j] = dp[N][j-1];
     56         }
     57     }
     58 }
     59 
     60 void add(int a,int b){
     61     point[sz] = b;
     62     nxt[sz] = head[a];
     63     head[a] = sz++;
     64 }
     65 
     66 
     67 int LCA(int u,int v){
     68     int left = min(id[u],id[v]),right = max(id[u],id[v]);
     69     int k = (int)(log2(right- left+1.0));
     70     int pos,N = right - (1<<k)+1;
     71     if(dep[dp[left][k]] < dep[dp[N][k]])pos = dp[left][k];
     72     else pos = dp[N][k];
     73     return vis[pos];
     74 }
     75 
     76 int q;
     77 
     78 inline int get(int a,int k){
     79     int res = a;
     80     for(int i = 0 ; (1ll << i ) <= k ; ++ i){
     81         if(k&(1ll<<i)){
     82             res = fa[i][res];
     83         }
     84     }
     85     return res;
     86 }
     87 
     88 void run(){
     89     while(q--){
     90         int a,b,k;
     91         scanf("%d%d%d",&a,&b,&k);
     92         int lca = LCA(a,b);
     93         int num = dep[id[a]] - dep[id[lca]] + dep[id[b]] - dep[id[lca]] + 1;
     94         int up = (num - 1)%k;
     95         int ans = val[a];
     96 //        printf("a : %d
    ",a);
     97         while(dep[id[a]] - dep[id[lca]] >= k){
     98             int Id = get(a,k);
     99             ans ^= val[Id];
    100         //    printf("a : %d
    ",Id);
    101             a = Id;
    102         }
    103         if(dep[id[b]] - dep[id[lca]] > up){
    104         //    printf("up: %d
    ",up);
    105                if(up == 0)ans^= val[b];
    106             b = get(b,up);
    107             while(dep[id[b]] - dep[id[lca]] >k ){
    108                 int Id = get(b,k);
    109                 ans ^= val[Id];
    110                 b = Id;
    111             }
    112         }
    113         printf("%d
    ",ans);
    114     }
    115 }
    116 
    117 int main(){
    118     while(scanf("%d%d",&n,&q)!=EOF){
    119 
    120         init();
    121         for(int i =1 ; i < n; ++ i){
    122             int a,b;
    123             scanf("%d%d",&a,&b);
    124             add(a,b);
    125             add(b,a);
    126         }
    127         for(int i = 1;i <= n ; ++ i)scanf("%d",&val[i]);
    128         RMQ(1);
    129         Pre();
    130         run();
    131 
    132     }
    133     return 0;
    134 }
  • 相关阅读:
    读写csv文件
    安卓跳转
    求时间精确到秒的数
    航空公司客户价值分析
    利用LM神经网络和决策树去分类
    拉格朗日插值法
    ID3
    K最近邻
    贝叶斯分类
    FilterDispatcher已被标注为过时解决办法
  • 原文地址:https://www.cnblogs.com/cenariusxz/p/4826875.html
Copyright © 2011-2022 走看看