zoukankan      html  css  js  c++  java
  • 【树形DP】codeforces K. Send the Fool Further! (medium)

    http://codeforces.com/contest/802/problem/K

    【题意】

    给定一棵树,Heidi从根结点0出发沿着边走,每个结点最多经过k次,求这棵树的最大花费是多少(同一条边走n次花费只算一次)

    【思路】

    对于结点v:

    • 如果在v的某棵子树停下,那么可以“遍历”k棵子树(有的话)
    • 如果还要沿着v返回v的父节点p,那么只能“遍历”k-1棵子树(有的话)。

    用dp[v][1]表示第一种情况,dp[v][0]表示第二种情况;最后要求的就是dp[0][0]。

    1. 对于dp[v][1],把所有的子树从大到小排序

    (t=k-1)

    2. 对于dp[v][0],枚举子结点dp[u][0]中的u,剩下的k-1个dp[u][1]取最大的,所以我们可以这样预处理:

    sum=

    (t=k)

    • 如果u<k,则target=sum-dp[u][1]+dp[u][0]
    • 否则,        target=sum-dp[t][1]+dp[u][0](t是从大到小排序后的第k-1个)

    这样,dp[0][0]就是所求结果(dp[0][0]一定大于dp[0][1]),时间复杂度是O(nlogn)

    【官方题解】

    【Accepted】

      1 #include<iostream>
      2 #include<cstdio>
      3 #include<cstring>
      4 #include<string>
      5 #include<cmath>
      6 #include<vector>
      7 #include<algorithm>
      8 
      9 using namespace std;
     10 int n,m;
     11 vector< vector< pair<int,int> > > g;
     12 const int maxn=1e5+5;
     13 int dp[maxn][2];
     14 void dfs(int v,int p,int edge)
     15 {
     16     //从p到v的花费要算在v里
     17     dp[v][0]+=edge;
     18     dp[v][1]+=edge;
     19     vector< pair<int,int> > s;
     20     //只有根结点没有父节点,非根结点有父节点,减去1
     21     if(v==0)
     22     {
     23         s.resize(g[v].size());
     24     }
     25     else
     26     {
     27         s.resize(g[v].size()-1);
     28     }
     29     //遍历
     30     int num=0;
     31     for(int i=0;i<g[v].size();i++)
     32     {
     33         int to=g[v][i].first;
     34         if(to==p)
     35         {
     36             continue;
     37          }
     38         dfs(to,v,g[v][i].second);
     39         s[num++]={dp[to][1],to};
     40     }
     41     //从大到小排序
     42     sort(s.begin(),s.end());
     43     reverse(s.begin(),s.end());
     44     //要记录各个子结点的rank,后面dp[v][0]枚举u是要分类
     45     int pos[maxn];
     46     for(int i=0;i<s.size();i++)
     47     {
     48         pos[s[i].second]=i;
     49     }
     50     //计算dp[v][1]
     51     for(int i=0;i<min(m-1,(int)s.size());i++)
     52     {
     53         dp[v][1]+=s[i].first;
     54     }
     55     //计算dp[v][0]
     56     int sum=0;
     57     for(int i=0;i<min(m,(int)s.size());i++)
     58     {
     59         sum+=s[i].first;
     60     }
     61     int maxu=-1;
     62     //枚举
     63     for(int i=0;i<g[v].size();i++)
     64     {
     65         int to=g[v][i].first;
     66         if(to==p)
     67         {
     68             continue;
     69         }
     70         if(pos[to]<m)
     71         {
     72             maxu=max(maxu,sum-dp[to][1]+dp[to][0]);
     73         }
     74         else
     75         {
     76             maxu=max(maxu,sum-s[m-1].first+dp[to][0]);
     77         }
     78     }
     79     if(maxu>-1)
     80     {
     81         dp[v][0]+=maxu;
     82     }
     83 }
     84 int main()
     85 {
     86     while(~scanf("%d%d",&n,&m))
     87     {
     88         memset(dp,0,sizeof(dp));
     89         g.resize(n);
     90         int u,v,c;
     91         for(int i=0;i<n-1;i++)
     92         {
     93             scanf("%d%d%d",&u,&v,&c);
     94             g[u].push_back({v,c});
     95             g[v].push_back({u,c});
     96         }
     97         //根结点为0,无父结点,根结点到父结点的花费也为0
     98         dfs(0,0,0);
     99         printf("%d
    ",dp[0][0]);
    100     }
    101     return 0;
    102  }
    View Code

    注意vector开始要resize.....orz

    【WA】

     1 #include<iostream>
     2 #include<cstdio>
     3 #include<cstring>
     4 #include<string>
     5 #include<algorithm>
     6 #include<cmath>
     7 
     8 using namespace std;
     9 int n,k;
    10 const int maxn=2e5+3; 
    11 struct edge
    12 {
    13     int to;
    14     int nxt;
    15     int c;    
    16 }e[maxn];
    17 int head[maxn];
    18 int tot;
    19 struct node
    20 {
    21     int x;
    22     int id;
    23 }sz[maxn];
    24 int rk[maxn];
    25 bool cmp(node a,node b)
    26 {
    27     return a.x>b.x;
    28 }
    29 void init()
    30 {
    31     memset(head,-1,sizeof(head));
    32     tot=0;
    33 }
    34 
    35 void add(int u,int v,int c)
    36 {
    37     e[tot].to=v;
    38     e[tot].c=c;
    39     e[tot].nxt=head[u];
    40     head[u]=tot++;
    41 }
    42 int dp[maxn][2];
    43 
    44 int dfs(int u,int pa,int c)
    45 {
    46     dp[u][1]=c;
    47     dp[u][0]=c;
    48     int cnt=0;
    49     for(int i=head[u];i!=-1;i=e[i].nxt)
    50     {
    51         int v=e[i].to;
    52         int c=e[i].c;
    53         if(v==pa) continue;
    54         dfs(v,u,c);
    55         sz[cnt].x=dp[v][1];
    56         sz[cnt++].id=v;
    57      } 
    58      sort(sz,sz+cnt,cmp);
    59      for(int i=0;i<min(cnt,k-1);i++)
    60      {
    61          dp[u][1]+=sz[i].x;
    62      }
    63      int sum=0;
    64      for(int i=0;i<min(cnt,k);i++)
    65      {
    66          sum+=sz[i].x;
    67      }
    68      int ans=0;
    69     for(int i=0;i<cnt;i++)
    70     {
    71         if(i<k)
    72         {
    73             ans=max(ans,sum-sz[i].x+dp[sz[i].id][0]);
    74         }
    75         else
    76         {
    77             ans=max(ans,sum-sz[k-1].x+dp[sz[i].id][0]);
    78         }
    79     }
    80      dp[u][0]+=ans;
    81 }
    82 int main()
    83 {
    84     while(~scanf("%d%d",&n,&k))
    85     {
    86         init();
    87         memset(dp,0,sizeof(dp));
    88         for(int i=0;i<n-1;i++)
    89         {
    90             int u,v,c;
    91             scanf("%d%d%d",&u,&v,&c);
    92             add(u,v,c);
    93             add(v,u,c);
    94         }
    95         dfs(0,-1,0);
    96         cout<<dp[0][0]<<endl; 
    97     }
    98     return 0;    
    99 } 
    Wrong Answer

    终于弄清楚了这个为什么WA!因为我在dfs里用了一个全局变量sz来保存{dp[v][1],v}。然而这是一个全局变量,所以一层里的正确值会被另一层修改!比如当我递归到0时已经有了正确值sz[0].w=5,sz[0].v=2;然而再递归到0的另一分枝1的时候,会修改sz[0],最后再回溯到0时sz[0]已经不是当年的sz[0]了!

    所以还是用vector临时申请吧!

    【AC(一个更优美的代码)】

      1 #include<iostream>
      2 #include<cstdio>
      3 #include<cstring>
      4 #include<string>
      5 #include<algorithm>
      6 #include<cmath>
      7 
      8 using namespace std;
      9 int n,k;
     10 const int maxn=2e5+3; 
     11 struct edge
     12 {
     13     int to;
     14     int nxt;
     15     int c;    
     16 }e[maxn];
     17 int head[maxn];
     18 int tot;
     19 int dp[maxn][2];
     20 
     21 struct node
     22 {
     23     int x;
     24     int id;
     25     node(){}
     26     node(int _x,int _id):x(_x),id(_id){}
     27     bool operator<(const node & nd) const
     28     {
     29         return x>nd.x;
     30     }
     31 };
     32 
     33 void init()
     34 {
     35     memset(head,-1,sizeof(head));
     36     tot=0;
     37 }
     38 
     39 void add(int u,int v,int c)
     40 {
     41     e[tot].to=v;
     42     e[tot].c=c;
     43     e[tot].nxt=head[u];
     44     head[u]=tot++;
     45 }
     46 
     47 int dfs(int u,int pa,int c)
     48 {
     49     dp[u][1]=c;
     50     dp[u][0]=c;
     51     vector<node> s;
     52     for(int i=head[u];i!=-1;i=e[i].nxt)
     53     {
     54         int v=e[i].to;
     55         int c=e[i].c;
     56         if(v==pa) continue;
     57         dfs(v,u,c);
     58         s.push_back(node(dp[v][1],v));
     59      } 
     60      sort(s.begin(),s.end());
     61      int sz=s.size();
     62      for(int i=0;i<min(sz,k-1);i++)
     63      {
     64          dp[u][1]+=s[i].x;
     65      }
     66      int sum=0;
     67      for(int i=0;i<min(sz,k);i++)
     68      {
     69          sum+=s[i].x;
     70      }
     71      int ans=0;
     72     for(int i=0;i<sz;i++)
     73     {
     74         if(i<k)
     75         {
     76             ans=max(ans,sum-s[i].x+dp[s[i].id][0]);
     77         }
     78         else
     79         {
     80             ans=max(ans,sum-s[k-1].x+dp[s[i].id][0]);
     81         }
     82     }
     83      dp[u][0]+=ans;
     84 }
     85 int main()
     86 {
     87     while(~scanf("%d%d",&n,&k))
     88     {
     89         init();
     90         memset(dp,0,sizeof(dp));
     91         for(int i=0;i<n-1;i++)
     92         {
     93             int u,v,c;
     94             scanf("%d%d%d",&u,&v,&c);
     95             add(u,v,c);
     96             add(v,u,c);
     97         }
     98         dfs(0,-1,0);
     99         cout<<dp[0][0]<<endl; 
    100     }
    101     return 0;    
    102 } 
    View Code

    如果是vector<pair<int,int>> 要从大到小排序,可以先sort(s.begin(),s.end()),再reverse(s.begin(),s.end())

  • 相关阅读:
    DRF项目创建流程(1)
    RESTful API规范
    超哥笔记--shell 基本命令(4)
    转:django模板标签{% for %}的使用(含forloop用法)
    自定django登录跳转限制
    jquery Ajax应用
    js重定向跳转页面
    django项目mysql中文编码问题
    python进阶(六) 虚拟环境git clone报错解决办法
    Linux基础(六) Vim之vundle插件
  • 原文地址:https://www.cnblogs.com/itcsl/p/6928544.html
Copyright © 2011-2022 走看看