zoukankan      html  css  js  c++  java
  • 暑假集训 || LCA && RMQ

    LCA定义为对于一颗树 树上两个点的最近公共祖先

    一.Tarjan求LCA(离线方法

    https://blog.csdn.net/lw277232240/article/details/77017517

    二.倍增法求LCA

    void dfs(int u, int f)
    {
        for(int i = 1; i <= 18; i++)
            if(deep[u] >= (1<<i))
                fa[u][i] = fa[fa[u][i-1]][i-1];
        for(int i = head[u];i;i = nxt[i])
        {
            int v = l[i].t;
            if(v != f)
            {
                deep[v] = deep[u] + 1;
                dist[v] = dist[u] + l[i].d;
                fa[v][0] = u;
                dfs(v, u);
            }
        }
    }
    int lca(int x, int y)
    {
        if(deep[x] < deep[y])
            swap(x, y);
        int delta = deep[x] - deep[y];
        for(int i = 0; i <= 18; i++)
            if((1<<i) & delta)
                x = fa[x][i];
        for(int i = 18; i >= 0; i--)
            if(fa[x][i] != fa[y][i])
            {
                x = fa[x][i];
                y = fa[y][i];
            }
        if(x == y) return x;
        else return fa[x][0];
    }
    LL getdis(int x, int y)
    {
        int z = lca(x, y);
        return dist[x] + dist[y] - 2 * dist[z];
    }

    可以用来求一棵树上两点之间的最短距离

    例题:

    Gym 101808K 思路题

    题意:给一个有n个点,n条边的图,n为1e5,查询任两点间的最短距离

    思路:n个点n-1条边的话就是树,这个图就比树多了一条边,把这条边拿出来考虑

    任意两点间的最短路有两种情况,一是经过这条边,二是不经过

    建图的时候不加这一条边

    x和y的距离 经过这条边的话x->uu + ww + vv->y或x->vv + ww + uu->y

    不经过的话直接求即可

    #include <iostream>
    #include <cstdio>
    #include <cstring>
    #include <cmath>
    #include <algorithm>
    #include <queue>
    using namespace std;
    typedef long long LL;
    const int SZ = 200010;
    const int INF = 1e9+10;
    int head[SZ], nxt[SZ], tot = 0, deep[SZ], fa[SZ][20];
    int fab[SZ];
    LL dist[SZ];
    struct node
    {
        int t, d;
    }l[SZ];
    void build(int f, int t, int d)
    {
        l[++tot].t = t;
        l[tot].d = d;
        nxt[tot] = head[f];
        head[f] = tot;
    }
    int n;
    void dfs(int u, int f)
    {
        for(int i = 1; i <= 18; i++)
            if(deep[u] >= (1<<i))
                fa[u][i] = fa[fa[u][i-1]][i-1];
        for(int i = head[u];i;i = nxt[i])
        {
            int v = l[i].t;
            if(v != f)
            {
                deep[v] = deep[u] + 1;
                dist[v] = dist[u] + l[i].d;
                fa[v][0] = u;
                dfs(v, u);
            }
        }
    }
    int lca(int x, int y)
    {
        if(deep[x] < deep[y])
            swap(x, y);
        int delta = deep[x] - deep[y];
        for(int i = 0; i <= 18; i++)
            if((1<<i) & delta)
                x = fa[x][i];
        for(int i = 18; i >= 0; i--)
            if(fa[x][i] != fa[y][i])
            {
                x = fa[x][i];
                y = fa[y][i];
            }
        if(x == y) return x;
        else return fa[x][0];
    }
    LL getdis(int x, int y)
    {
        int z = lca(x, y);
        return dist[x] + dist[y] - 2 * dist[z];
    }
    void init()
    {
        memset(head, 0, sizeof(head));
        tot = 0;
        memset(deep, 0, sizeof(deep));
        memset(fa, 0, sizeof(fa));
        for(int i = 1; i <= n; i++) fab[i] = i;
    }
    int find(int x)
    {
        return x == fab[x] ? x : fab[x] = find(fab[x]);
    }
    int main()
    {
        int T;
        scanf("%d", &T);
        while(T--)
        {
            int q;
            scanf("%d %d", &n, &q);
            init();
            int uu, vv, ww;
            for(int i = 0; i < n; i++)
            {
                int u, v, w;
                scanf("%d %d %d", &u, &v, &w);
                int fu = find(u), fv = find(v);
                if(fu == fv) uu = u, vv = v, ww = w;
                else
                {
                    fab[fu] = fv;
                    build(u, v, w);
                    build(v, u, w);
                }
            }
            dist[1] = 0;
            dfs(1, 0);
            //printf("%d %d
    ", uu, vv);
            //for(int i = 1; i <= n; i++) printf("%lld ", dist[i]);
            while(q--)
            {
                int x, y;
                scanf("%d %d", &x, &y);
                LL dis1 = getdis(x, y);
                LL dis2 = min(getdis(x, uu) + ww + getdis(y, vv), getdis(x, vv) + ww + getdis(y, uu));
                printf("%lld
    ", min(dis1, dis2));
            }
        }
        return 0;
    }

    Gym 101810M

    题意:一棵树,每条边来回都可以获得不同的值,每条边只能去一次回一次,任意查询从s到t最多能获得多少值

    思路:发现能获得的值是整棵树上的权值 - 从t走最短路到s获得的值

    用dist[0][u]记录从根走到u总的值

    用dist[1][u]记录从u走到根总的值

    画个图推个式子就ok了

    #include <iostream>
    #include <cstdio>
    #include <cstring>
    #include <cmath>
    #include <algorithm>
    #include <queue>
    using namespace std;
    typedef long long LL;
    const int SZ = 400010;
    const int INF = 1e9+10;
    int head[SZ], nxt[SZ], tot = 0, deep[SZ], fa[SZ][20];
    int dist[2][SZ];
    struct node
    {
        int t, c1, c2;
    }l[SZ];
    void build(int f, int t, int c1, int c2)
    {
        l[++tot].t = t;
        l[tot].c1 = c1;
        l[tot].c2 = c2;
        nxt[tot] = head[f];
        head[f] = tot;
    }
    int n;
    void dfs(int u, int f)
    {
        for(int i = 1; i <= 18; i++)
            if(deep[u] >= (1<<i))
                fa[u][i] = fa[fa[u][i-1]][i-1];
        for(int i = head[u];i;i = nxt[i])
        {
            int v = l[i].t;
            if(v != f)
            {
                deep[v] = deep[u] + 1;
                dist[0][v] = dist[0][u] + l[i].c1;
                dist[1][v] = dist[1][u] + l[i].c2;
                fa[v][0] = u;
                dfs(v, u);
            }
        }
    }
    int lca(int x, int y)
    {
        if(deep[x] < deep[y])
            swap(x, y);
        int delta = deep[x] - deep[y];
        for(int i = 0; i <= 18; i++)
            if((1<<i) & delta)
                x = fa[x][i];
        for(int i = 18; i >= 0; i--)
            if(fa[x][i] != fa[y][i])
            {
                x = fa[x][i];
                y = fa[y][i];
            }
        if(x == y) return x;
        else return fa[x][0];
    }
    void init()
    {
        memset(head, 0, sizeof(head));
        tot = 0;
        memset(deep, 0, sizeof(deep));
        memset(fa, 0, sizeof(fa));
    }
    int main()
    {
        int T;
        scanf("%d", &T);
        while(T--)
        {
            scanf("%d", &n);
            init();
            int sum = 0;
            for(int i = 0; i < n-1; i++)
            {
                int u, v, w1, w2;
                scanf("%d %d %d %d", &u, &v, &w1, &w2);
                build(u, v, w1, w2);
                build(v, u, w2, w1);
                sum = sum + w1 + w2;
            }
            dist[0][1] = 0, dist[1][1] = 0;
            dfs(1, 0);
            int q;
            scanf("%d", &q);
            while(q--)
            {
                int x, y;
                scanf("%d %d", &x, &y);
                int z = lca(x, y);
                int ans = dist[1][y] - dist[1][z] + dist[0][x] - dist[0][z];
                printf("%d
    ", sum - ans);
            }
        }
        return 0;
    }

    RMQ:区间最值查询问题

    用f[i][j]表示 从a[i] 开始 往后2^j个数里面的 最大/最小 值或GCD

    void st(int n)
    {
        for(int i = 1; i <= n; i++)
            f[i][0] = a[i];
        for(int j = 1; (1 << j) <= n; j++)
            for(int i = 1; i + (1 << j) - 1 <= n; i++)
                f[i][j] = max(f[i][j-1], f[i + (1<<(j-1))][j-1]);
    }
    int RMQ(int l, int r)
    {
        int k = 0;
        while((1<<(k + 1) <= r - l + 1)) k++;
        return max(f[l][k], f[r - (1<<k) + 1][k]);
    }

    最小值把max改成min, GCD把max改成__gcd

    例题:

    POJ 2019 二维RMQ

    #include <iostream>
    #include <cstdio>
    #include <cstring>
    #include <cmath>
    #include <algorithm>
    #include <queue>
    using namespace std;
    const int SZ = 550;
    const int INF = 1e9+10;
    int a[SZ][SZ], mmax[SZ][SZ][15], mmin[SZ][SZ][15];
    void st(int n)
    {
        for(int i = 1; i <= n; i++)
            for(int j = 1; j <= n; j++)
                mmax[i][j][0] = mmin[i][j][0] = a[i][j];
        for(int j = 1; (1 << j) <= n; j++)
            for(int i = 1; i + (1 << j) - 1 <= n; i++)
                for(int k = 1; k <= n; k++)
                {
                    mmax[k][i][j] = max(mmax[k][i][j-1], mmax[k][i + (1<<(j-1))][j-1]);
                    mmin[k][i][j] = min(mmin[k][i][j-1], mmin[k][i + (1<<(j-1))][j-1]);
                }
    }
    int RMQ(int x, int l, int r, int b)
    {
        int k = 0;
        while((1<<(k + 1) <= r - l + 1)) k++;
        int ans_max = -INF, ans_min = INF;
        for(int i = x; i < x + b; i++)
        {
            ans_max = max(ans_max, max(mmax[i][l][k], mmax[i][r- (1<<k) + 1][k]));
            ans_min = min(ans_min, min(mmin[i][l][k], mmin[i][r- (1<<k) + 1][k]));
        }
        return (ans_max - ans_min);
    }
    int main()
    {
        int n, b, k;
        scanf("%d %d %d", &n, &b, &k);
        for(int i = 1; i <= n; i++)
            for(int j = 1; j <= n; j++)
                scanf("%d", &a[i][j]);
        st(n);
        for(int i = 0; i < k; i++)
        {
            int x, y;
            scanf("%d %d", &x, &y);
            int l = y, r = y + b - 1;
            printf("%d
    ", RMQ(x, l, r, b));
        }
        return 0;
    }

    HDU 5726

    题意:给一段序列,任意查询一个区间内所有数的GCD,以及有多少个区间的GCD数和它相同

    思路:RMQ+二分。。考场上有思路了但是没敢敲QAQ

    发现序列越长,GCD是不增的,于是对于每个数可以通过二分判断它往后多少个数,这一段里面拥有相同的GCD

    用map记录即可

    #include <iostream>
    #include <cstdio>
    #include <cstring>
    #include <cmath>
    #include <algorithm>
    #include <queue>
    #include <map>
    using namespace std;
    typedef long long LL;
    const int SZ = 100010;
    const int INF = 1e9+10;
    int a[SZ];
    int f[SZ][22];
    map<int, LL> mp;
    void st(int n)
    {
        for(int i = 1; i <= n; i++)
            f[i][0] = a[i];
        for(int j = 1; (1 << j) <= n; j++)
            for(int i = 1; i + (1 << j) - 1 <= n; i++)
            {
                f[i][j] = __gcd(f[i][j-1], f[i + (1<<(j-1))][j-1]);
            }
    }
    int RMQ(int l, int r)
    {
        int k = 0;
        while((1<<(k + 1) <= r - l + 1)) k++;
        return __gcd(f[l][k], f[r - (1<<k) + 1][k]);
    }
    int main()
    {
        int T, tt = 0;
        scanf("%d", &T);
        while(T--)
        {
            int n;
            scanf("%d", &n);
            mp.clear();
            for(int i = 1; i <= n; i++)
                for(int j = 1; j <= 18; j++)
                    f[i][j] = 1;
            for(int i = 1; i <= n; i++)
                scanf("%d", &a[i]);
            st(n);
            for(int i = 1; i <= n; i++)
            {
                int k = i;
                while(k <= n)
                {
                    int l = k, r = n;
                    while(l <= r)
                    {
                        int mid = (l + r + 1) >> 1;
                        if(RMQ(i, mid) < RMQ(i, k)) r = mid - 1;
                        else l = mid + 1;
                    }
                    mp[RMQ(i, k)] += LL(l - k);
                    k = l;
                }
            }
            int q;
            scanf("%d", &q);
            printf("Case #%d:
    ", ++tt);
            while(q--)
            {
                int x, y;
                scanf("%d %d", &x, &y);
                int ans = RMQ(x, y);
                printf("%d %lld
    ", ans, mp[ans]);
            }
        }
        return 0;
    }
  • 相关阅读:
    位图
    3. 资源管理(条款:13-17)
    70. Implement strStr() 与 KMP算法
    69. Letter Combinations of a Phone Number
    68. Longest Common Prefix
    67. Container With Most Water
    66. Regular Expression Matching
    65. Reverse Integer && Palindrome Number
    波浪理论
    MACD理解
  • 原文地址:https://www.cnblogs.com/pinkglightning/p/9520801.html
Copyright © 2011-2022 走看看