题目大意:
在给定带权值节点的树上从1开始不回头走到某个底端点后得到所有经过的点的权值后,这些点权值修改为0,到达底部后重新回到1,继续走,问走k次,最多能得到多少权值之和
这其实就是相当于每一次都走权值最大的那一条路径,进行贪心k次
首先先来想想树链剖分的时候的思想:
重儿子表示这个儿子对应的子树的节点数最多,那么每次访问都优先访问重儿子
这道题里面我们进行一下转化,如果当前儿子能走出一条最长的路径,我们就令其为重儿子,那么很容易想到,到达父亲时,如果选择重儿子,那么之前到达
父亲所得的权值一定是记录在重儿子这条路径上的,那么访问轻儿子的时候,因为前面的值在到达重儿子后修改为0,所以走到轻儿子之前权值和修改为0
我们将所有到达底端点的路径长度保存到rec数组中,将rec排序取前k个即可,如果不够取,相当于全部取完,因为后面再走也就是相当于0,不必计算
1 #include <iostream> 2 #include <cstdio> 3 #include <cstring> 4 #include <algorithm> 5 #include <queue> 6 using namespace std; 7 #define N 100100 8 #define ll long long 9 10 int first[N] , k , t; 11 //t记录底层节点的个数,rec[i]记录到达i节点的那个时候经过的长度 12 ll rec[N]; 13 14 struct Edge{ 15 int y , next; 16 Edge(int y=0 , int next=0):y(y),next(next){} 17 }e[N]; 18 19 void add_edge(int x , int y) 20 { 21 e[k] = Edge(y , first[x]); 22 first[x] = k++; 23 } 24 25 bool cmp(ll a , ll b) 26 { 27 return a>b; 28 } 29 30 ll val[N] , down[N];//down[i]记录从i开始往下能走到的最长的路径 31 int heavyson[N]; 32 void dfs(int u) 33 { 34 ll maxn = -1; 35 for(int i=first[u] ; ~i ; i=e[i].next) 36 { 37 int v = e[i].y; 38 dfs(v); 39 if(maxn<down[v]){ 40 heavyson[u] = v; 41 maxn = down[v]; 42 } 43 } 44 if(maxn>=0) down[u] = maxn+val[u]; 45 else down[u] = val[u]; 46 } 47 48 void dfs1(int u , ll cur) 49 { 50 bool flag = true; //判断是否为底层节点 51 if(heavyson[u]){ 52 dfs1(heavyson[u] , cur+val[heavyson[u]]); 53 flag = false; 54 } 55 for(int i=first[u] ; ~i ; i=e[i].next) 56 { 57 int v = e[i].y; 58 if(v == heavyson[u]) continue; 59 dfs1(v , val[v]); 60 flag = false; 61 } 62 if(flag) rec[t++] = cur; 63 } 64 65 int main() 66 { 67 #ifndef ONLINE_JUDGE 68 freopen("a.in" , "r" , stdin); 69 #endif 70 int T , cas=0; 71 scanf("%d" , &T); 72 while(T--) 73 { 74 printf("Case #%d: " , ++cas); 75 int n,m; 76 scanf("%d%d" , &n , &m); 77 memset(first , -1 , sizeof(first)); 78 k=0; 79 for(int i=1 ; i<=n ; i++) scanf("%I64d" , val+i); 80 for(int i=1 ; i<n ; i++){ 81 int u,v; 82 scanf("%d%d" , &u , &v); 83 add_edge(u , v); 84 } 85 memset(heavyson , 0 , sizeof(heavyson)); 86 dfs(1); 87 t=0; 88 dfs1(1 , val[1]); 89 sort(rec , rec+t , cmp); 90 ll ret = 0; 91 for(int i=0 ; i<t ; i++){ 92 if(i==m) break; 93 ret+=rec[i]; 94 } 95 printf("%I64d " , ret); 96 } 97 return 0; 98 }