题意:带边权树上有白点和黑点,问你最多不经过k个黑点使得路径最长(注意,路径有负数)
解题思路:基于树的点分治。数的路径问题,具体看09QZC论文,特别注意 当根为黑时的情况
解题代码:
![](https://images.cnblogs.com/OutliningIndicators/ContractedBlock.gif)
1 // File Name: spoj1825.cpp 2 // Author: darkdream 3 // Created Time: 2014年10月05日 星期日 20时20分33秒 4 5 #include<vector> 6 #include<list> 7 #include<map> 8 #include<set> 9 #include<deque> 10 #include<stack> 11 #include<bitset> 12 #include<algorithm> 13 #include<functional> 14 #include<numeric> 15 #include<utility> 16 #include<sstream> 17 #include<iostream> 18 #include<iomanip> 19 #include<cstdio> 20 #include<cmath> 21 #include<cstdlib> 22 #include<cstring> 23 #include<ctime> 24 #define LL long long 25 #define maxn 200015 26 using namespace std; 27 struct node{ 28 int ne; 29 int w; 30 node(int _ne,int _w) 31 { 32 ne = _ne ; 33 w = _w; 34 } 35 }; 36 int n ,K, m ; 37 int col[maxn]; 38 int vis[maxn]; 39 vector <node> mp[maxn]; 40 int sum[maxn]; 41 int mx[maxn]; 42 int cnum[maxn]; 43 void getsize(int k,int la) 44 { 45 sum[k] = 1; 46 mx[k] = 0; 47 int num = mp[k].size(); 48 int tt = 0 ; 49 for(int i = 0 ;i < num;i ++) 50 { 51 if(!vis[mp[k][i].ne] && mp[k][i].ne != la) 52 { 53 getsize(mp[k][i].ne,k); 54 mx[k] = max(sum[mp[k][i].ne],mx[k]); 55 sum[k] += sum[mp[k][i].ne]; 56 } 57 } 58 } 59 int root; 60 int mxv; 61 int getroot(int k,int la ,int tans) 62 { 63 int tt = max(tans - sum[k],mx[k]); 64 if(tt < mxv) 65 { 66 mxv = tt; 67 root = k ; 68 } 69 int num = mp[k].size(); 70 for(int i = 0 ;i < num ;i ++) 71 { 72 if(!vis[mp[k][i].ne] && mp[k][i].ne != la) 73 { 74 getroot(mp[k][i].ne,k,tans); 75 } 76 } 77 } 78 LL ans = 0 ; 79 LL dp[maxn]; 80 LL tdp[maxn]; 81 bool cmp(node a, node b) 82 { 83 return cnum[a.ne] < cnum[b.ne]; 84 } 85 void getdep(int k ,int la,int tc,LL dep) 86 { 87 int st = (col[k] == 1?1:0) ; 88 tdp[tc+st] = max(tdp[tc+st],dep); //这个点是G点的时候 89 int num = mp[k].size(); 90 for(int i = 0 ;i < num ;i ++) 91 { 92 if(!vis[mp[k][i].ne] && mp[k][i].ne != la ) 93 { 94 getdep(mp[k][i].ne,k,tc + st,dep + mp[k][i].w); 95 } 96 } 97 } 98 void getcnum(int k ,int la) 99 { 100 if(col[k]) 101 cnum[k] = 1; 102 else cnum[k] = 0 ; 103 int tt = 0 ; 104 int num = mp[k].size(); 105 for(int i = 0 ;i < num;i ++) 106 { 107 if(!vis[mp[k][i].ne] && mp[k][i].ne != la) 108 { 109 getcnum(mp[k][i].ne,k); 110 if(cnum[mp[k][i].ne] > tt) 111 tt = cnum[mp[k][i].ne]; 112 } 113 } 114 cnum[k] += tt; 115 } 116 void solve(int k) 117 { 118 getsize(k,0); 119 mxv = 1e9; 120 getroot(k,0,sum[k]); 121 k = root; 122 123 getcnum(k,0); 124 //printf("*****%d %d ",k,cnum[k]); 125 int num = mp[k].size(); 126 memset(dp,0,(cnum[k]+3)*sizeof(LL)); 127 int tk ; 128 int st = 0 ; 129 if(col[k]) 130 { 131 tk = K + 1; 132 st = 1; 133 } 134 else tk = K ; 135 int la =0 ; 136 //int size = min(cnum[k],K); 137 sort(mp[k].begin(),mp[k].end(),cmp); 138 for(int i = 0 ;i < num ;i ++) 139 { 140 if(vis[mp[k][i].ne]) 141 continue; 142 143 memset(tdp,0,(cnum[mp[k][i].ne]+3)*sizeof(tdp[0])); 144 if(col[k]) 145 getdep(mp[k][i].ne,k,1,mp[k][i].w); 146 else 147 getdep(mp[k][i].ne,k,0,mp[k][i].w); 148 // printf("**********%d ",tk); 149 150 151 int tt = min(cnum[mp[k][i].ne]+st,K); 152 // printf("%d %d ",cnum[mp[k][i].ne]+st,K); 153 for(int j = st ;j <= tt;j ++) 154 { 155 if(tk - j <= la) 156 { 157 if(tdp[j] + dp[tk-j]> ans) 158 { 159 ans = tdp[j] + dp[tk-j]; 160 } 161 }else{ 162 if(tdp[j] + dp[la]> ans) 163 { 164 ans = tdp[j] + dp[la]; 165 } 166 } 167 } 168 dp[0] = max(dp[0],tdp[0]); 169 //printf("%d %d ",n,cnum[mp[k][i].ne]); 170 /*if(tdp[tt+st+1] != 0) 171 { 172 printf("&&&&&&&&&&&&&& "); 173 }*/ 174 for(int j = 1 ;j <= tt+st; j ++) 175 { 176 dp[j] = max(dp[j],tdp[j]); 177 dp[j] = max(dp[j],dp[j-1]); 178 } 179 // for(int j = 0;j <= K;j ++) 180 // printf("%lld ",dp[j]); 181 // puts(""); 182 la = tt + st; 183 } 184 //puts("**********8"); 185 vis[k] = 1; 186 for(int i = 0;i < num;i ++) 187 { 188 if(!vis[mp[k][i].ne]) 189 solve(mp[k][i].ne); 190 } 191 return; 192 } 193 int main(){ 194 //freopen("out","r",stdin); 195 //freopen("output.txt","w",stdin); 196 while(scanf("%d %d %d",&n,&K,&m) != EOF){ 197 int temp ; 198 memset(vis,0,sizeof(vis)); 199 memset(col,0,sizeof(col)); 200 for(int i = 1;i <= n;i ++) 201 mp[i].clear(); 202 for(int i = 1;i <= m;i ++) 203 { 204 scanf("%d",&temp); 205 col[temp] = 1; 206 } 207 for(int i = 1;i <= n - 1;i ++) 208 { 209 int a, b , w; 210 scanf("%d %d %d",&a,&b,&w); 211 mp[a].push_back(node(b,w)); 212 mp[b].push_back(node(a,w)); 213 } 214 ans = 0; 215 solve(1); 216 printf("%lld ",ans); 217 } 218 return 0; 219 }