zoukankan      html  css  js  c++  java
  • lca最近公共祖先(模板)

    洛谷上的lca模板题——传送门

    1.tarjan求lca

    学了求lca的tarjan算法(离线),在洛谷上做模板题,结果后三个点超时。

    又把询问改成链式前向星,才ok。

    这个博客,tarjan分析的很详细。

    附代码——

    #include <cstdio>
    #include <cstring>
    
    const int maxn = 500001;
    
    int n, m, cnt, s, cns;
    int x, y, z[maxn];//z是x和y的lca 
    int f[maxn], head[maxn], from[maxn];
    bool vis[maxn];
    struct node
    {
        int to, next;
    }e[2 * maxn];
    struct Node
    {
        int to, next, num;
    }q[2 * maxn];
    
    inline int read()//读入优化 
    {
        int x = 0, f = 1;
        char ch = getchar();
        while(ch < '0' || ch > '9')
        {
            if(ch == '-') f = -1;
            ch = getchar();
        }
        while(ch >= '0' && ch <= '9')
        {
            x = x * 10 + ch - '0';
            ch = getchar();
        }
        return x * f;
    }
    
    inline void ask(int u, int v, int i)//储存待询问的结构体,也是链式前向星优化 
    {
        q[cns].num = i;//num表示第几次询问 
        q[cns].to = v;
        q[cns].next = from[u];
        from[u] = cns++;
    }
    
    inline void add(int u, int v)//
    {
        e[cnt].to = v;
        e[cnt].next = head[u];
        head[u] = cnt++;
    }
    
    inline int find(int a)
    {
        return a == f[a] ? a : f[a] = find(f[a]);//路径压缩优化 
    }
    
    /*inline void Union(int a, int b)
    {
        int fx = find(a), fy = find(b);
        if(fx == fy) return;
        f[fy] = fx;
    }*/
    
    inline void tarjan(int k)
    {
        int i, j;
        vis[k] = 1;
        f[k] = k;
        for(i = head[k]; i != -1; i = e[i].next)
         if(!vis[e[i].to])
         {
              tarjan(e[i].to);
              //Union(k, e[i].to);
              f[e[i].to] = k;
         }
        for(i = from[k]; i != -1; i = q[i].next)
         if(vis[q[i].to] == 1)
          z[q[i].num] = find(q[i].to);
    }
    
    int main()
    {
        int i, j, u, v;
        n = read();
        m = read();
        s = read();
        memset(head, -1, sizeof(head));
        memset(from, -1, sizeof(from));
        for(i = 1; i <= n - 1; i++)
        {
            u = read();
            v = read();
            add(u, v);//注意添加两遍 
            add(v, u);
        }
        for(i = 1; i <= m; i++)
        {
            x = read();
            y = read();
            ask(x, y, i);//两遍 
            ask(y, x, i);
        }
        tarjan(s);
        for(i = 1; i <= m; i++) printf("%d
    ", z[i]);
        return 0;
    }
    View Code

    进过培训,修改了代码

     1 # include <iostream>
     2 # include <cstdio>
     3 # include <cstring>
     4 # include <string>
     5 # include <cmath>
     6 # include <vector>
     7 # include <map>
     8 # include <queue>
     9 # include <cstdlib>
    10 # define MAXN 500001
    11 using namespace std;
    12 
    13 inline int get_num() {
    14     int k = 0, f = 1;
    15     char c = getchar();
    16     for(; !isdigit(c); c = getchar()) if(c == '-') f = -1;
    17     for(; isdigit(c); c = getchar()) k = k * 10 + c - '0';
    18     return k * f;
    19 }
    20 
    21 int n, m, s;
    22 int fa[MAXN], qx[MAXN], qy[MAXN], ans[MAXN], f[MAXN];
    23 vector <int> vec[MAXN], q[MAXN];
    24 
    25 inline int find(int x)
    26 {
    27     return x == fa[x] ? x : fa[x] = find(fa[x]);
    28 }
    29 
    30 inline void dfs(int u)
    31 {
    32     int i, v;
    33     fa[u] = u;
    34     for(i = 0; i < vec[u].size(); i++)
    35     {
    36         v = vec[u][i];
    37         if(f[u] != v) f[v] = u, dfs(v);
    38     }
    39     for(i = 0; i < q[u].size(); i++)
    40         if(f[v = u ^ qx[q[u][i]] ^ qy[q[u][i]]])
    41             ans[q[u][i]] = find(v);
    42     fa[u] = f[u];
    43 }
    44 
    45 int main()
    46 {
    47     int i, x, y;
    48     n = get_num();
    49     m = get_num();
    50     s = get_num();
    51     for(i = 1; i < n; i++)
    52     {
    53         x = get_num();
    54         y = get_num();
    55         vec[x].push_back(y);
    56         vec[y].push_back(x);
    57     }
    58     for(i = 1; i <= m; i++)
    59     {
    60         qx[i] = get_num();
    61         qy[i] = get_num();
    62         q[qx[i]].push_back(i);
    63         q[qy[i]].push_back(i);
    64     }
    65     dfs(s);
    66     for(i = 1; i <= m; i++) printf("%d
    ", ans[i]);
    67     return 0;
    68 }
    View Code

    其实上面两个代码有些重复运算,请手动把求lca的过程放到dfs上面(也就是遍历到这个节点就求lca,而不是遍历完再求)

    2.倍增求lca

    下面是求lca的倍增算法(在线)

    1. DFS预处理出所有节点的深度和父节点

    inline void dfs(int u)
    {
        int i;
        for(i=head[u];i!=-1;i=next[i])  
        {  
            if (!deep[to[i]])
            {            
                deep[to[i]] = deep[u]+1;
                p[to[i]][0] = u; //p[x][0]保存x的父节点为u;
                dfs(to[i]);
            }
        }
    }
    dfs预处理

    2. 初始各个点的2^j祖先是谁 ,其中 2^j (j =0...log(该点深度))倍祖先,1倍祖先就是父亲,2倍祖先是父亲的父亲......。

    void init()
    {
        int i,j;
        //p[i][j]表示i结点的第2^j祖先
        for(j=1;(1<<j)<=n;j++)
            for(i=1;i<=n;i++)
                if(p[i][j-1]!=-1)
                    p[i][j]=p[p[i][j-1]][j-1];//i的第2^j祖先就是i的第2^(j-1)祖先的第2^(j-1)祖先
    }
    初始化

    3.从深度大的节点上升至深度小的节点同层,如果此时两节点相同直接返回此节点,即lca。

    否则,利用倍增法找到最小深度的 p[a][j]!=p[b][j],此时他们的父亲p[a][0]即lca。

    int lca(int a,int b)//最近公共祖先
    {
        int i,j;
        if(deep[a]<deep[b])swap(a,b);
        for(i=0;(1<<i)<=deep[a];i++);
        i--;
        //使a,b两点的深度相同
        for(j=i;j>=0;j--)
            if(deep[a]-(1<<j)>=deep[b])
                a=p[a][j];
        if(a==b)return a;
        //倍增法,每次向上进深度2^j,找到最近公共祖先的子结点
        for(j=i;j>=0;j--)
        {
            if(p[a][j]!=-1&&p[a][j]!=p[b][j])
            {
                a=p[a][j];
                b=p[b][j];
            }
        }
        return p[a][0];
    }
    倍增求lca

    最后是完整代码,为了节约时间,就没有把p数组初始化为-1.

    #include <cstdio>
    #include <cstring>
    #include <iostream>
    
    const int maxn = 500001;
    int n, m, cnt, s;
    int next[2 * maxn], to[2 * maxn], head[2 * maxn], deep[maxn], p[maxn][21];
    
    inline void add(int x, int y)
    {
        to[cnt] = y;
        next[cnt] = head[x];
        head[x] = cnt++;
    }
    
    inline void dfs(int i)
    {
        int j;
        for(j = head[i]; j != -1; j = next[j])
         if(!deep[to[j]])
         {
             deep[to[j]] = deep[i] + 1;
             p[to[j]][0] = i;
             dfs(to[j]);
         }
    }
    
    inline void init()
    {
        int i, j;
        for(j = 1; (1 << j) <= n; j++)
         for(i = 1; i <= n; i++)
          p[i][j] = p[p[i][j - 1]][j - 1];
    }
    
    inline int lca(int a, int b)
    {
        int i, j;
        if(deep[a] < deep[b]) std::swap(a, b);
        for(i = 0; (1 << i) <= deep[a]; i++);
        i--;
        for(j = i; j >= 0; j--)
         if(deep[a] - (1 << j) >= deep[b])
          a = p[a][j];
        if(a == b) return a;
        for(j = i; j >= 0; j--)
         if(p[a][j] != p[b][j])
         {
             a = p[a][j];
             b = p[b][j];
         }
        return p[a][0];
    }
    
    int main()
    {
        int i, j, x, y;
        memset(head, -1, sizeof(head));
        scanf("%d %d %d", &n, &m, &s);
        for(i = 1; i <= n - 1; i++)
        {
            scanf("%d %d", &x, &y);
            add(x, y);
            add(y, x);
        }
        deep[s] = 1;
        dfs(s);
        init();
        for(i = 1; i <= m; i++)
        {
            scanf("%d %d", &x, &y);
            printf("%d
    ", lca(x, y));
        }
        return 0;
    }
    View Code

    经过培训,又改了改代码。

     1 # include <iostream>
     2 # include <cstdio>
     3 # include <cstring>
     4 # include <string>
     5 # include <cmath>
     6 # include <vector>
     7 # include <map>
     8 # include <queue>
     9 # include <cstdlib>
    10 # define MAXN 500001
    11 using namespace std;
    12 
    13 inline int get_num() {
    14     int k = 0, f = 1;
    15     char c = getchar();
    16     for(; !isdigit(c); c = getchar()) if(c == '-') f = -1;
    17     for(; isdigit(c); c = getchar()) k = k * 10 + c - '0';
    18     return k * f;
    19 }
    20 
    21 int n, m, s;
    22 int f[MAXN][25], deep[MAXN];
    23 vector <int> vec[MAXN];
    24 
    25 inline void dfs(int u)
    26 {
    27     int i, v;
    28     deep[u] = deep[f[u][0]] + 1;
    29     for(i = 0; f[u][i]; i++) f[u][i + 1] = f[f[u][i]][i];
    30     for(i = 0; i < vec[u].size(); i++)
    31     {
    32         v = vec[u][i];
    33         if(!deep[v]) f[v][0] = u, dfs(v);
    34     }
    35 }
    36 
    37 inline int lca(int x, int y)
    38 {
    39     int i;
    40     if(deep[x] < deep[y]) swap(x, y);
    41     for(i = 20; i >= 0; i--)
    42         if(deep[f[x][i]] >= deep[y])
    43             x = f[x][i];
    44     if(x == y) return x;
    45     for(i = 20; i >= 0; i--)
    46         if(f[x][i] != f[y][i])
    47             x = f[x][i], y = f[y][i];
    48     return f[x][0];
    49 }
    50 
    51 int main()
    52 {
    53     int i, x, y;
    54     n = get_num();
    55     m = get_num();
    56     s = get_num();
    57     for(i = 1; i < n; i++)
    58     {
    59         x = get_num();
    60         y = get_num();
    61         vec[x].push_back(y);
    62         vec[y].push_back(x);
    63     }
    64     dfs(s);
    65     for(i = 1; i <= m; i++)
    66     {
    67         scanf("%d %d", &x, &y);
    68         printf("%d
    ", lca(x, y));
    69     }
    70     return 0;
    71 }
    View Code

    3.树剖法求lca

     1 # include <iostream>
     2 # include <cstdio>
     3 # include <cstring>
     4 # include <string>
     5 # include <cmath>
     6 # include <vector>
     7 # include <map>
     8 # include <queue>
     9 # include <cstdlib>
    10 # define MAXN 500001
    11 using namespace std;
    12 
    13 inline int get_num() {
    14     int k = 0, f = 1;
    15     char c = getchar();
    16     for(; !isdigit(c); c = getchar()) if(c == '-') f = -1;
    17     for(; isdigit(c); c = getchar()) k = k * 10 + c - '0';
    18     return k * f;
    19 }
    20 
    21 int n, m, s;
    22 int f[MAXN], size[MAXN], top[MAXN], son[MAXN], deep[MAXN];
    23 vector <int> vec[MAXN];
    24 
    25 inline void dfs1(int u)
    26 {
    27     int i, v;
    28     size[u] = 1;
    29     deep[u] = deep[f[u]] + 1;
    30     for(i = 0; i < vec[u].size(); i++)
    31     {
    32         v = vec[u][i];
    33         if(!deep[v])
    34         {
    35             f[v] = u;
    36             dfs1(v);
    37             size[u] += size[v];
    38             if(size[son[u]] < size[v]) son[u] = v;
    39         }
    40     }
    41 }
    42 
    43 inline void dfs2(int u, int tp)
    44 {
    45     int i, v;
    46     top[u] = tp;
    47     if(!son[u]) return;
    48     dfs2(son[u], tp);
    49     for(i = 0; i < vec[u].size(); i++)
    50     {
    51         v = vec[u][i];
    52         if(v != son[u] && v != f[u]) dfs2(v, v);
    53     }
    54 }
    55 
    56 inline int lca(int x, int y)
    57 {
    58     while(top[x] != top[y])
    59     {
    60         if(deep[top[x]] < deep[top[y]]) swap(x, y);
    61         x = f[top[x]];
    62     }
    63     if(deep[x] > deep[y]) swap(x, y);
    64     return x;
    65 }
    66 
    67 int main()
    68 {
    69     int i, x, y;
    70     n = get_num();
    71     m = get_num();
    72     s = get_num();
    73     for(i = 1; i < n; i++)
    74     {
    75         x = get_num();
    76         y = get_num();
    77         vec[x].push_back(y);
    78         vec[y].push_back(x);
    79     }
    80     dfs1(s);
    81     dfs2(s, s);
    82     for(i = 1; i <= m; i++)
    83     {
    84         x = get_num();
    85         y = get_num();
    86         printf("%d
    ", lca(x, y));
    87     }
    88     return 0;
    89 }
    View Code
  • 相关阅读:
    Celery 分布式任务队列入门
    异步通信----WebSocket
    爬虫框架之scrapy
    《JavaScript 高级程序设计》第一章:简介
    NodeJS学习:环境变量
    cmd 与 bash 基础命令入门
    H5开发中的故障
    认识 var、let、const
    netsh & winsock & 对前端的影响
    scrollify
  • 原文地址:https://www.cnblogs.com/zhenghaotian/p/6658105.html
Copyright © 2011-2022 走看看