zoukankan      html  css  js  c++  java
  • 2018青岛网络预选赛 B."Red Black Tree"(LCA+二分答案)

    传送门

    •参考资料

      [1]:ACM-ICPC 2018 青岛赛区网络预赛 B. Red Black Tree (LCA、二分)

    •题意 

      给出一棵树,根节点为1。

      每条边有一个权值,树上有红色结点 m 个,其花费为 0 ,其余为黑色;

      每个黑色结点的花费为其到最近红色祖先的经过的路径权值之和。

      有 q 次询问,每次给出一个点集;

      问将树上任意一个结点涂成红色结点后,点集中所有点的花费的最大值的最小是多少。

    •题解

      相关变量解释:

        sum : 每次询问中询问的点集个数

        a[  ]  : 存储每次询问到的点集

        costR[i] : 结点 i 距其最近红色祖先的花费

      预处理每个点到根的距离cost、到最近红色祖先的距离 costR 和 ST 表。

      对于每次询问,将a[ ] 按 costR 从大到小排序,在 0~costR[a[0]] 范围内二分答案;

      对所有大于答案的点求它们的公共祖先(利用ST表可以O(1)求两点的公共祖先),将其涂红;

      之后计算每个大于答案的点的新花费是否小于答案。

    •Code

      1 #include<iostream>
      2 #include<vector>
      3 #include<cstdio>
      4 #include<cmath>
      5 #include<algorithm>
      6 #include<cstring>
      7 using namespace std;
      8 #define pb push_back
      9 #define ll long long
     10 #define mem(a,b) (memset(a,b,sizeof a))
     11 const int maxn=1e5+50;
     12 
     13 int n,m,q;
     14 //===============Restore Graph============
     15 struct Node
     16 {
     17     int to;
     18     ll w;
     19     Node(int to,int w):to(to),w(w){}
     20 };
     21 vector<Node >G[maxn];
     22 void addEdge(int u,int v,int w)
     23 {
     24     G[u].pb(Node(v,w));
     25     G[v].pb(Node(u,w));
     26 }
     27 //=========================================
     28 int vs[2*maxn];//欧拉序列,范围区间为 [1,total]
     29 int depth[2*maxn];//欧拉序列对应的深度序列
     30 int pos[maxn];//pos[i] : 结点 i 再欧拉序列中第一次出现的位置
     31 ll cost[maxn];//cost[i] : 结点 i 距根据点的距离
     32 ll costR[maxn];//costR[i] : 结点 i 距最近红色祖先结点的距离,初始化为 -1
     33 int total;//欧拉序列的大小
     34 void dfs(int u,int f,int dep,ll dis)
     35 {
     36     vs[++total]=u;
     37     depth[total]=dep;
     38     pos[u]=total;
     39     cost[u]=dis;
     40     for(int i=0;i < G[u].size();++i)
     41     {
     42         Node e=G[u][i];
     43         if (e.to == f)
     44             continue;
     45         costR[e.to]=(costR[e.to] == 0 ? 0:costR[u]+e.w);
     46         dfs(e.to,u,dep+1,dis+e.w);
     47         vs[++total]=u;
     48         depth[total]=dep;
     49     }
     50 }
     51 //==================RMQ======================
     52 struct Node2
     53 {
     54     int mm[2 * maxn];
     55     int dp[2 * maxn][20];
     56     void ST()
     57     {
     58         int n=total;
     59         mm[0] = -1;
     60         for (int i = 1; i <= n; i++)
     61         {
     62             mm[i]=((i&(i-1))==0) ? mm[i - 1] + 1:mm[i - 1];
     63             dp[i][0]=i;
     64         }
     65         for (int j=1;j <= mm[n];j++)
     66             for (int i=1;i+(1<<j)-1 <= n;i++)
     67                 if(depth[dp[i][j - 1]] < depth[dp[i+(1<<(j-1))][j-1]])
     68                     dp[i][j]=dp[i][j-1];
     69                 else
     70                     dp[i][j]=dp[i+(1<<(j-1))][j-1];
     71     }
     72     int Lca(int u, int v)
     73     {
     74         u=pos[u],v=pos[v];
     75         if (u > v)
     76             swap(u, v);
     77         int k = mm[v-u+1];
     78         if(depth[dp[u][k]] <= depth[dp[v-(1<<k)+1][k]])
     79             return vs[dp[u][k]];
     80         return vs[dp[v-(1<<k)+1][k]];
     81     }
     82 }_rmq;
     83 //==========================================
     84 int a[maxn];
     85 int sum;
     86 bool cmp(int a, int b)
     87 {
     88     return costR[a] > costR[b];
     89 }
     90 bool Check(ll x)
     91 {
     92     if(costR[a[0]] <= x)
     93         return true;
     94     int lca=a[0];
     95     for(int i=1;i < sum;i++)
     96     {
     97         if(costR[a[i]] <= x)
     98             break;
     99         lca=_rmq.Lca(lca,a[i]);
    100     }
    101     for(int i = 0;i < sum;i++)
    102     {
    103         if(costR[a[i]] <= x)
    104             return true;
    105         if(cost[a[i]]-cost[lca] > x)
    106             return false;
    107     }
    108     return true;
    109 }
    110 void Solve()
    111 {
    112     dfs(1,-1,0,0);
    113     _rmq.ST();
    114     while(q--)
    115     {
    116         scanf("%d",&sum);
    117         for (int i=0;i < sum; i++)
    118             scanf("%d",&a[i]);
    119         sort(a,a+sum,cmp);
    120         ll l=0,r=costR[a[0]];
    121         while(l < r)
    122         {
    123             ll mid=(l+r)/2;
    124             if(Check(mid))
    125                 r=mid;
    126             else
    127                 l=mid + 1;
    128         }
    129         printf("%lld
    ",l);
    130     }
    131 }
    132 void init()
    133 {
    134     mem(costR,-1);
    135     total=0;
    136     for(int i=0;i < maxn;++i)
    137         G[i].clear();
    138 }
    139 int main()
    140 {
    141     int t;
    142     scanf("%d", &t);
    143     while(t--)
    144     {
    145         init();
    146         scanf("%d%d%d",&n,&m,&q);
    147         while(m--)
    148         {
    149             int red;
    150             scanf("%d",&red);
    151             costR[red]=0;
    152         }
    153         costR[1]=0;
    154         for(int i=1;i<n;i++)
    155         {
    156             int u,v,w;
    157             scanf("%d%d%d",&u,&v,&w);
    158             addEdge(u,v,w);
    159         }
    160         Solve();
    161     }
    162     return 0;
    163 }
    View Code

    •出现的问题

      1、用 vector 存储图比用 链式前向星存储图要慢

        (1)vector : 

        (2)链式前向星:

      2、平常一直在用的RMQ会超时

     1 //=====================RMQ===================
     2 struct Node1
     3 {
     4     int dp[20][2*maxn];
     5     void Preset()
     6     {
     7         for(int i=0;i < 2*maxn;++i)
     8             dp[0][i]=i;
     9     }
    10     void ST()
    11     {
    12         int k=log(total)/log(2);
    13         for(int i=1;i <= k;++i)
    14             for(int j=1;j <= (total-(1<<i)+1);++j)
    15                 if(depth[dp[i-1][j]] > depth[dp[i-1][j+(1<<(i-1))]])
    16                     dp[i][j]=dp[i-1][j+(1<<(i-1))];
    17                 else
    18                     dp[i][j]=dp[i-1][j];
    19     }
    20     int Lca(int u,int v)
    21     {
    22         u=pos[u],v=pos[v];
    23         if(u > v)
    24             swap(u,v);
    25         int k=log(v-u+1)/log(2);
    26         if(depth[dp[k][u]] > depth[dp[k][v-(1<<k)+1]])
    27             return vs[dp[k][v-(1<<k)+1]];
    28         return vs[dp[k][u]];
    29     }
    30 }_rmq;
    31 //===========================================
    TLE
     1 //==================RMQ======================
     2 struct Node2
     3 {
     4     int mm[2 * maxn];
     5     int dp[2 * maxn][20];
     6     void ST()
     7     {
     8         int n=total;
     9         mm[0] = -1;
    10         for (int i = 1; i <= n; i++)
    11         {
    12             mm[i]=((i&(i-1))==0) ? mm[i - 1] + 1:mm[i - 1];
    13             dp[i][0]=i;
    14         }
    15         for (int j=1;j <= mm[n];j++)
    16             for (int i=1;i+(1<<j)-1 <= n;i++)
    17                 if(depth[dp[i][j - 1]] < depth[dp[i+(1<<(j-1))][j-1]])
    18                     dp[i][j]=dp[i][j-1];
    19                 else
    20                     dp[i][j]=dp[i+(1<<(j-1))][j-1];
    21     }
    22     int Lca(int u, int v)
    23     {
    24         u=pos[u],v=pos[v];
    25         if (u > v)
    26             swap(u, v);
    27         int k = mm[v-u+1];
    28         if(depth[dp[u][k]] <= depth[dp[v-(1<<k)+1][k]])
    29             return vs[dp[u][k]];
    30         return vs[dp[v-(1<<k)+1][k]];
    31     }
    32 }_rmq;
    33 //==========================================
    AC

      3、cost[ ] 很有用,如果 Check( ) 中不加    

          if(cost[a[i]]-cost[lca] > x)
            return false;

        会返回 WA,具体为什么,明天再好好想想%%%%%%%%%

     


    分割线:2019.5.8

      中石油的这场重现赛又让我回想起了这道题留下的疑惑;

      现在再想想这道题,思路清晰了些许;

      一些不理解的地方瞬间顿悟了;

      ST表处理RMQ中,会多次求解 log2(x),这种算式是比较耗时的,我们预处理出所需的log2(x);

    logTwo[i]=log2(i);

      如何预处理呢?

      首先想一下,三位数的二进制数的最大值为 111(2),四位数的二进制数的最小值为 1000(2)

      两者的关系是 (111)&(1000) = 0 , 而对于任意三位二进制数 x,y ,(x&y) != 0;

      有了这个关系后,就可以这么预处理了:

    logTwo[0]=-1;
    for(int i=1;i <= n;++i)
        logTwo[i]=(i&(i-1)) == 0 ? logTwo[i-1]+1:logTwo[i-1];

      这就是之前一直不理解的ST表加速的地方;

    •Code

      1 #include<bits/stdc++.h>
      2 using namespace std;
      3 #define ll long long
      4 #define mem(a,b) memset(a,b,sizeof(a))
      5 #define INFll 0x3f3f3f3f3f3f3f3f
      6 const int maxn=1e5+50;
      7 
      8 int n,m,q;
      9 ll C[maxn];///C[i]:节点i到根节点1的花费
     10 ll CR[maxn];///CR[i]:节点i到其最近的红色祖先节点的花费
     11 int num;
     12 int head[maxn];
     13 struct Edge
     14 {
     15     int to;
     16     ll w;
     17     int next;
     18 }G[maxn<<1];
     19 void addEdge(int u,int v,ll w)
     20 {
     21     G[num]={v,w,head[u]};
     22     head[u]=num++;
     23 }
     24 struct LCA
     25 {
     26     int vs[maxn<<1];///欧拉序列
     27     int dep[maxn<<1];///欧拉序列中的节点对应的深度序列
     28     int pos[maxn<<1];///pos[i]:节点i在欧拉序列中第一次出现的位置
     29     int cnt;
     30     int logTwo[maxn<<1];///logTwo[i]:log2(i)
     31     int dp[maxn<<1][20];///dp[i][j]:[i,i+2^j-1]深度最小的点的下标(欧拉序列中的下标)
     32     void DFS(int u,int f,int depth,ll dist)
     33     {
     34         vs[++cnt]=u;
     35         dep[cnt]=depth;
     36         pos[u]=cnt;
     37         C[u]=dist;
     38         for(int i=head[u];~i;i=G[i].next)
     39         {
     40             int v=G[i].to;
     41             ll w=G[i].w;
     42             if(v == f)
     43                 continue;
     44             CR[v]=min(CR[v],CR[u]+w);
     45             DFS(v,u,depth+1,dist+w);
     46             vs[++cnt]=u;
     47             dep[cnt]=depth;
     48         }
     49     }
     50     void ST()
     51     {
     52         logTwo[0]=-1;
     53         for(int i=1;i <= cnt;++i)
     54         {
     55             dp[i][0]=i;
     56             ///:后的语句写错了,刚开始写成了logTwo[i],debug了好一会
     57             logTwo[i]=(i&(i-1)) == 0 ? logTwo[i-1]+1:logTwo[i-1];
     58         }
     59         for(int k=1;k <= logTwo[cnt];++k)
     60             for(int i=1;i+(1<<k)-1 <= cnt;++i)
     61                 if(dep[dp[i][k-1]] > dep[dp[i+(1<<(k-1))][k-1]])
     62                     dp[i][k]=dp[i+(1<<(k-1))][k-1];
     63                 else
     64                     dp[i][k]=dp[i][k-1];
     65     }
     66     void lcaInit(int root)
     67     {
     68         cnt=0;
     69         DFS(root,root,0,0);
     70         ST();
     71     }
     72     int lca(int u,int v)///返回节点u,v的LCA
     73     {
     74         u=pos[u];
     75         v=pos[v];
     76 
     77         if(u > v)
     78             swap(u,v);
     79 
     80         int k=logTwo[v-u+1];
     81         if(dep[dp[u][k]] > dep[dp[v-(1<<k)+1][k]])
     82             return vs[dp[v-(1<<k)+1][k]];
     83         else
     84             return vs[dp[u][k]];
     85     }
     86 }_lca;
     87 
     88 int qCnt;
     89 int query[maxn<<1];
     90 
     91 bool Check(ll mid)
     92 {
     93     int lca=0;///不满足条件的点的LCA
     94     for(int i=1;i <= qCnt;++i)
     95     {
     96         if(CR[query[i]] <= mid)
     97             continue;
     98         if(lca == 0)
     99             lca=query[i];
    100         else/// > mid的点LCA
    101             lca=_lca.lca(lca,query[i]);
    102     }
    103 
    104     for(int i=1;i <= qCnt;++i)
    105     {
    106         if(CR[query[i]] <= mid)
    107             continue;
    108 
    109         ///如果将lca点涂红后还不能使其 <= mid,返回false
    110         if(C[query[i]]-C[lca] > mid)
    111             return false;
    112     }
    113     return true;
    114 }
    115 void Solve()
    116 {
    117     _lca.lcaInit(1);
    118 
    119     for(int i=1;i <= q;++i)
    120     {
    121         scanf("%d",&qCnt);
    122 
    123         ll l=-1,r=0;
    124         for(int j=1;j <= qCnt;++j)
    125         {
    126             scanf("%d",query+j);
    127             r=max(r,CR[query[j]]);
    128         }
    129 
    130         while(r-l > 1)
    131         {
    132             ll mid=l+((r-l)>>1);
    133             if(Check(mid))
    134                 r=mid;
    135             else
    136                 l=mid;
    137         }
    138         printf("%lld
    ",r);
    139     }
    140 }
    141 void Init()
    142 {
    143     num=0;
    144     mem(head,-1);
    145     mem(CR,INFll);///初始化为最大值
    146 }
    147 int main()
    148 {
    149 //    freopen("C:\Users\hyacinthLJP\Desktop\in&&out\contest","r",stdin);
    150     int test;
    151     scanf("%d",&test);
    152     while(test--)
    153     {
    154         Init();
    155         scanf("%d%d%d",&n,&m,&q);
    156         for(int i=1;i <= m;++i)
    157         {
    158             int red;
    159             scanf("%d",&red);
    160             CR[red]=0;
    161         }
    162         CR[1]=0;
    163         for(int i=1;i < n;++i)
    164         {
    165             int u,v,w;
    166             scanf("%d%d%d",&u,&v,&w);
    167             addEdge(u,v,w);
    168             addEdge(v,u,w);
    169         }
    170         Solve();
    171     }
    172     return 0;
    173 }
    View Code
  • 相关阅读:
    今天,我们来聊一聊互联网真的有你所期待的那么好吗?来自一个老码农的碎碎念
    新鲜出炉!阿里Java后端面经,已拿offer!
    面试阿里,字节跳动,美团必被问到的红黑树原来这么简单
    凭借着这份Spring面试题,我拿到了阿里,字节跳动美团的offer!
    深度分析:理解Java中的多态机制,一篇直接帮你掌握!
    gdb调试core dump使用
    665. Non-decreasing Array
    netstat命令详解
    ifconfig命令
    #paragma详解
  • 原文地址:https://www.cnblogs.com/violet-acmer/p/9677889.html
Copyright © 2011-2022 走看看