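// 每一次枚举到重心 按子树中的黑点数SORT一下 启发式合并
// (For each centroid: sort its child subtrees by their number of black vertices, then merge them small-to-large.)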
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

// Maximum total edge cost of a tree path that contains at most k black
// vertices, computed with centroid decomposition.
//
// Input (stdin): n k m, then m black vertex ids, then n-1 lines "u v c".
// Output (stdout): the best path cost (at least 0, the empty/single-vertex path).

const int MAXN = 2e6 + 5;
const int MAXM = 2e6 + 5;

// Forward-star adjacency list. Edge ids start at 2 (ed starts at 1) so that
// id 0 can terminate every list.
int to[MAXM << 1], nxt[MAXM << 1], Head[MAXN], ed = 1;
int cost[MAXM << 1];

// INT_MAX; -INF in h[]/g[] means "no path with that many black vertices".
const int INF = ~0u >> 1;

// Appends the directed edge u -> v with cost c.
inline void addedge(int u, int v, int c) {
    to[++ed] = v;
    cost[ed] = c;
    nxt[ed] = Head[u];
    Head[u] = ed;
}

// Adds the undirected edge {u, v} with cost c as two directed edges.
inline void ADD(int u, int v, int c) {
    addedge(u, v, c);
    addedge(v, u, c);
}

// Reads a (possibly negative) decimal integer from stdin.
// Stops at EOF instead of spinning forever; returns 0 if no digits follow.
inline int readin() {
    int r = 0, sign = 1;
    int c = getchar();  // int, not char, so EOF is distinguishable
    for (; c != EOF && (c < '0' || c > '9'); c = getchar())
        if (c == '-') sign = -1;
    for (; c >= '0' && c <= '9'; c = getchar())
        r = r * 10 + c - '0';
    return sign * r;
}

int n, k, kk, m, anser, cnt, summaxdep;
int sz[MAXN], f[MAXN], sumsz, root;
bool vis[MAXN];       // vertex already used as a centroid
int ok[MAXN];         // 1 if the vertex is black
int blasz[MAXN];      // number of black vertices in the subtree (getdeep)
int h[MAXN];          // h[j]: max depth into the current subtree with exactly j blacks
int g[MAXN];          // g[j]: max depth into merged subtrees with at most j blacks

struct node {
    int blaval;  // black vertices in this child subtree
    int id;      // edge id from the centroid to the child
} o[MAXN];

bool cmp(const node& a, const node& b) { return a.blaval < b.blaval; }

// Computes subtree sizes of the current component (size sumsz) and keeps in
// `root` the vertex whose largest remaining piece f[] is smallest.
void getroot(int x, int fa) {
    sz[x] = 1;
    f[x] = 0;
    for (int i = Head[x]; i; i = nxt[i]) {
        int v = to[i];
        if (v == fa || vis[v]) continue;
        getroot(v, x);
        sz[x] += sz[v];
        f[x] = max(f[x], sz[v]);
    }
    f[x] = max(f[x], sumsz - sz[x]);
    if (f[x] < f[root]) root = x;
}

// Records into h[] the deepest path from the centroid for each black count
// (centroid excluded); prunes once the count exceeds kk.
void update(int x, int blanum, int deep, int fa) {
    if (blanum > kk) return;
    h[blanum] = max(h[blanum], deep);
    for (int i = Head[x]; i; i = nxt[i]) {
        int v = to[i];
        if (vis[v] || v == fa) continue;
        update(v, blanum + ok[v], deep + cost[i], x);
    }
}

// Fills blasz[] with the number of black vertices in each subtree.
void getdeep(int x, int fa) {
    blasz[x] = ok[x];
    for (int i = Head[x]; i; i = nxt[i]) {
        int v = to[i];
        if (v == fa || vis[v]) continue;
        getdeep(v, x);
        blasz[x] += blasz[v];
    }
}

// Collects the unvisited child subtrees of centroid x into o[1..cnt].
void calc(int x) {
    cnt = 0;
    for (int i = Head[x]; i; i = nxt[i]) {
        int v = to[i];
        if (vis[v]) continue;
        getdeep(v, x);
        node now;
        now.blaval = blasz[v];
        now.id = i;
        o[++cnt] = now;
    }
}

// Handles all paths through centroid x, then recurses into its pieces.
void solve(int x) {
    summaxdep = -1;      // g[0..summaxdep] is valid for this centroid
    kk = k - ok[x];      // black budget left for the two halves of a path
    vis[x] = 1;
    calc(x);
    // Ascending black counts keep s nondecreasing, so `summaxdep = s` never
    // shrinks g's valid range, and each loop below costs O(s) for this subtree.
    sort(o + 1, o + cnt + 1, cmp);
    for (int i = 1; i <= cnt; i++) {
        int v = to[o[i].id];
        int c = cost[o[i].id];
        int s = min(o[i].blaval, kk);
        for (int j = 0; j <= s; j++) h[j] = -INF;
        update(v, ok[v], c, x);
        if (i == 1) {
            for (int j = 0; j <= s; j++) g[j] = h[j];
        } else {
            // Pair the current subtree with every earlier one through x.
            for (int j = 0; j <= s; j++) {
                int aim = min(kk - j, summaxdep);
                if (h[j] != -INF && g[aim] != -INF)
                    anser = max(anser, h[j] + g[aim]);
            }
            // Beyond the old valid range g[] holds data from earlier
            // centroids; take h[j] there instead of mixing in stale values.
            for (int j = 0; j <= s; j++)
                g[j] = (j <= summaxdep) ? max(h[j], g[j]) : h[j];
        }
        summaxdep = s;
        for (int j = 1; j <= summaxdep; j++) g[j] = max(g[j], g[j - 1]);
    }
    // One-sided paths ending at x. summaxdep < 0 when x has no unvisited
    // neighbours or when x alone exceeds the budget (kk < 0): g[-1] would be read.
    if (summaxdep >= 0)
        anser = max(anser, g[min(kk, summaxdep)]);

    int totsz = sumsz;
    for (int i = Head[x]; i; i = nxt[i]) {
        int v = to[i];
        if (vis[v]) continue;
        root = 0;
        // sz[] was computed from the DFS that found x: the neighbour with the
        // larger size is x's DFS parent, whose piece is everything except x's subtree.
        sumsz = sz[v] > sz[x] ? totsz - sz[x] : sz[v];
        getroot(v, 0);
        solve(root);
    }
}

int main() {
    cnt = anser = 0;
    n = readin(), k = readin(), m = readin();
    for (int i = 1; i <= m; i++) {
        int now = readin();
        ok[now] = 1;
    }
    for (int i = 1; i < n; i++) {
        int u = readin(), v = readin(), c = readin();
        ADD(u, v, c);
    }
    root = 0, sumsz = f[0] = n;  // f[0] = n: sentinel any real vertex beats
    getroot(1, 0);
    solve(root);
    printf("%d ", anser);
    return 0;
}