zoukankan      html  css  js  c++  java
  • HDU6031 Innumerable Ancestors 倍增

    去博客园看该题解

    题目

    查看原题 - HDU6031 Innumerable Ancestors

    题目描述

      有一棵有n个节点的有根树,根节点为1,其深度为1,现在有m个询问,每次询问给出两个集合A和B,问LCA(x,y)(x∈A,y∈B)的深度最大为多少。

    输入描述

      有多组数据(数据组数<=5)

      对于每一组数据,首先2个数n,m,表示有根树的节点个数和询问个数。然后n-1行,每行2个数a,b表示节点a和节点b之间存在直接的连边;接下去2m行,每两行,分别描述当前询问的集合A和集合B;对于一个集合,用一行来描述,该行第一个数K表示集合元素的个数,后面K个数表示集合中的元素。

    输出描述

      一个整数,表示LCA(x,y)(x∈A,y∈B)的最大深度。

    数据范围

      n,m<=100000, 1<=a,b<=n, ΣK<=100000, 1<=集合中的元素<=n

     

    题解

      问最大深度,那么我们思考是否可以二分答案。

      当然可以,本题的条件满足二分答案的前提,LCA基本的性质还是比较明显的。(假设a和b深度一样)设anst[x][y]为节点x往上走y步到达的祖先,对于一个k,如果anst[a][k]==anst[b][k],那么对于k'(k'>k),一定有anst[a][k']==anst[b][k'];对于一个k,如果anst[a][k]!=anst[b][k],那么对于k'(k'<k),一定有anst[a][k']!=anst[b][k'],而且LCA(a,b)=LCA(anst[a][k],anst[b][k])。

      二分答案深度d完成之后,那么就剩下了编一个子程序判定的事情了。

      那么如果判定呢?

      已知祖先深度,那么就知道了每一个点所对应的祖先了是吧?那么,判断是否有公共祖先,其实就是判断A集合所对应的祖先集合与B集合所对应的祖先集合是否有交集——因为ΣK<=100000, 所以对于每一个集合元素找出它的某一深度的祖先这个复杂度貌似还是不够,ΣK*n应该会超(如果有人用ΣK*n的判定复杂度过了本题, 跪求留代码) 。那么我们要更快的找到这个祖先,那么是什么?倍增啊!

      fa[i][j]表示与节点i深度差为2^j的i的祖先,那么不难写出转移方程:

      fa[i][0]=father[i],fa[i][j]=fa[fa[i][j-1]][j-1] (father[i]表示节点i的父亲节点)

      So,求某一深度的祖先就是和倍增求LCA的前一半类似的了。

      至于两个集合判断交集,就是排个序,然后两个指针扫过去就可以了。

      注意: 在求祖先时,要首先把那些不合法的祖先过滤掉; 在判断交集的时候,要注意边界情况!

    代码

     

    #include <cstring>
    #include <cstdlib>
    #include <algorithm>
    #include <cstdio>
    #include <cmath>
    using namespace std;
    const int N=100005,M=N*2,rt=1;
    struct Edge{
        int cnt,y[M],nxt[M],fst[N];
        void set(){
            cnt=0;
            memset(y,0,sizeof y);
            memset(nxt,0,sizeof nxt);
            memset(fst,0,sizeof fst);
        }
        void add(int a,int b){
            y[++cnt]=b,nxt[cnt]=fst[a],fst[a]=cnt;
        }
    }e;
    int n,m,depth[N],fa[N][20],ta,a[N],tb,b[N],ansta[N],anstb[N];
    void build(int prev,int rt){
        fa[rt][0]=prev,depth[rt]=depth[prev]+1;
        for (int i=1;(1<<i)<=depth[rt];i++)
            fa[rt][i]=fa[fa[rt][i-1]][i-1];
        for (int i=e.fst[rt];i;i=e.nxt[i])
            if (e.y[i]!=prev)
                build(rt,e.y[i]);
    }
    int get_kth_anst(int p,int k){
        for (int i=k,j=0;i>0;i>>=1,j++)
            if (i&1)
                p=fa[p][j];
        return p;
    }
    bool check(int d){
        int at=0,bt=0;
        for (int i=1;i<=ta;i++)
            if (depth[a[i]]>=d)
                ansta[++at]=get_kth_anst(a[i],depth[a[i]]-d);
        for (int i=1;i<=tb;i++)
            if (depth[b[i]]>=d)
                anstb[++bt]=get_kth_anst(b[i],depth[b[i]]-d);
        if (at==0||bt==0)
            return 0;
        int pa=1,pb=1;
        sort(ansta+1,ansta+at+1);
        sort(anstb+1,anstb+bt+1);
        if (ansta[1]==anstb[1])
            return 1;
        while (pa<=at&&pb<=bt){
            while (pa<=at&&ansta[pa]<anstb[pb])
                pa++;
            if (pa>at)
                break;
            if (ansta[pa]==anstb[pb])
                return 1;
            while (pb<=bt&&ansta[pa]>anstb[pb])
                pb++;
            if (pb>bt)
                break;
            if (ansta[pa]==anstb[pb])
                return 1;
        }
        return 0;
    }
    int main(){
        while (~scanf("%d%d",&n,&m)){
            e.set();
            for (int i=1,a,b;i<n;i++)
                scanf("%d%d",&a,&b),e.add(a,b),e.add(b,a);
            depth[0]=-1;
            build(0,rt);
            while (m--){
                scanf("%d",&ta);
                for (int i=1;i<=ta;i++)
                    scanf("%d",&a[i]);
                scanf("%d",&tb);
                for (int i=1;i<=tb;i++)
                    scanf("%d",&b[i]);
                int le=0,ri=n-1,mid,ans=0;
                while (le<=ri){
                    mid=(le+ri)>>1;
                    if (check(mid))
                        le=mid+1,ans=mid;
                    else
                        ri=mid-1;
                }
                printf("%d
    ",ans+1);
            }
        }
        return 0;
    }
      

    为了方便大家找茬,特地附上一份造数据的PASCAL代码,用于对拍。

    var
        t, i: longint;
    function min(a, b: longint): longint;
        begin
            if (a > b) then
                exit(b);
            exit(a);
        end;
    procedure make_list(n ,m: longint);
        var
            i, j: longint;
        begin
            write(m, ' ');
            j := 0;
            for i := 1 to m do 
            begin
                j := j + random(n - j - m + i) + 1;
                write(j, ' ');
            end;
            writeln;
        end;
    procedure mkdata;
        const 
            maxn = 150;
            maxm = 150;
            add = 40;
        var
            n, m, i, j, x, y, a, b: longint;
        begin
            n := random(maxn) + 1;
            m := random(maxm) + 1;
            writeln(n, ' ', m);
            for i := 2 to n do 
            begin
                x := i;
                y := random(i - 1) + 1;
                if (random(2) = 1) then
                    writeln(x, ' ', y)
                else
                    writeln(y, ' ', x);
            end;
            writeln;
            for i := 1 to m do 
            begin
                a := min(random(maxn div m + add)+2, n);
                b := min(random(maxn div m + add)+2, n);
                make_list(n, a);
                make_list(n, b);
            end;
            writeln;
        end;
    begin
        assign(output, 'anst.in');
        rewrite(output);
        randomize;
        t := random(2) + 1;
        for i := 1 to t do
            mkdata;
        close(output);
    end.
  • 相关阅读:
    HTML入门之003
    html入门之002
    HTML入门之001
    端口
    计算机基础
    二进制的学习
    markdown基础
    css基础
    html基础之三
    html基础之二
  • 原文地址:https://www.cnblogs.com/zhouzhendong/p/HDU6031.html
Copyright © 2011-2022 走看看