明显虚树。
别的题解里都是这样说的。
先不考虑虚树,假设只有一组询问,该如何dp?
f[u]表示把子树u中所有的有资源的节点都切掉的最优解
如果节点u本身需要切掉的话,$f[u]=val[u]$;
否则如果u的子树中有需要切除的点的话,$f[u] = \min\left(val[u], \sum\limits_{v \in son(u)} f[v]\right)$,其中$son(u)$表示u的所有儿子。
其中val[u]表示根到u的路径上边权的最小值(切断这条最小的边即可把u与根断开)。
最后把这个dp转移到虚树上即可。注意代码中的命名与上文不同:代码里的dp[u]对应上文的val[u],ans[u]对应上文的f[u],而代码里的f[u][i]是倍增求LCA用的祖先数组。
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <iostream>

// Answers offline queries on a rooted, edge-weighted tree (root = node 1).
// Each query marks a set of nodes; the answer is the minimum total weight of
// edges to cut so that every marked node is separated from the root.
// Per query, a virtual tree is built over the marked nodes and their LCAs and
// a small tree DP is run on it.

typedef long long LL;

const int N = 1000000;                     // capacity of every node/edge array
const int LOG = 20;                        // top binary-lifting level (2^20 > N)
const LL INF = 1000000000000000000LL;      // "no edge above the root" sentinel

int n, m, cnt, rp, top, T;
// head/to/nex form a linked adjacency list. It first stores the input tree and
// is then reused, query by query, for the virtual tree.
int head[N], to[N], nex[N], dfn[N], f[N][LOG + 1], q[N], deep[N], s[N];
// dp[u]  : minimum edge weight on the path from the root to u.
// ans[u] : DP value of u inside the current virtual tree.
// val[e] : weight of adjacency-list edge e.
LL ans[N], dp[N], val[N];
bool flag[N];                              // flag[u] != 0 <=> u is marked in the current query

// Reads one (optionally negative) decimal integer from stdin.
// Returns 0 if the input ends before any digit is found.
int read() {
    // Keep the result of getchar() in an int: EOF stays distinguishable and
    // isdigit() never receives a negative char value.
    int ch = getchar();
    int sign = 1;
    while (ch != EOF && !isdigit(ch)) {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    int x = 0;
    while (ch != EOF && isdigit(ch)) {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * sign;
}

// Appends the directed edge x -> y with weight z to the adjacency list.
void add(int x, int y, int z) {
    to[cnt] = y;
    val[cnt] = z;
    nex[cnt] = head[x];
    head[x] = cnt++;
}

// Preprocessing DFS over the input tree: assigns DFS order (dfn), depth,
// the binary-lifting ancestor table f and dp (minimum edge weight from root).
// NOTE(review): recursion depth equals the tree height; a very deep tree may
// need a larger stack.
void dfs1(int u) {
    dfn[u] = ++rp;
    deep[u] = deep[f[u][0]] + 1;
    // f[u][i + 1] is the 2^(i+1)-th ancestor; stop once an ancestor is missing (0).
    for (int i = 0; f[u][i]; i++) f[u][i + 1] = f[f[u][i]][i];
    for (int i = head[u]; ~i; i = nex[i]) {
        const int v = to[i];
        if (!dfn[v]) {  // dfn == 0 means "not visited yet"; this skips the parent
            f[v][0] = u;
            dp[v] = std::min(dp[u], val[i]);
            dfs1(v);
        }
    }
    // Clear the list so the same arrays can hold virtual-tree edges later.
    head[u] = -1;
}

// Lowest common ancestor of x and y by binary lifting.
int calc_lca(int x, int y) {
    if (deep[x] < deep[y]) std::swap(x, y);
    // Lift x to the depth of y. Missing ancestors are node 0 with deep[0] == 0,
    // which never satisfies the comparison.
    for (int i = LOG; i >= 0; i--)
        if (deep[f[x][i]] >= deep[y]) x = f[x][i];
    if (x == y) return x;
    for (int i = LOG; i >= 0; i--)
        if (f[x][i] != f[y][i]) {
            x = f[x][i];
            y = f[y][i];
        }
    return f[x][0];
}

// Orders nodes by DFS entry time.
bool cmp(int x, int y) {
    return dfn[x] < dfn[y];
}

// Tree DP on the virtual tree rooted at u.
// A marked node costs dp[u]. An unmarked node costs the cheaper of dp[u] and
// the sum over its virtual children.
void dfs2(int u) {
    LL sum = 0;
    ans[u] = dp[u];
    for (int i = head[u]; ~i; i = nex[i]) {
        const int v = to[i];
        dfs2(v);
        sum += ans[v];
    }
    // sum == 0 is treated as "no children" (assumes positive edge weights -- TODO confirm).
    if (sum && !flag[u]) ans[u] = std::min(ans[u], sum);
    head[u] = -1;  // leave the adjacency list empty for the next query
}

// Reads one query, builds its virtual tree with a monotone stack and prints the answer.
void solve() {
    m = read();
    top = cnt = 0;
    for (int i = 1; i <= m; i++) {
        q[i] = read();
        flag[q[i]] = true;
    }
    if (m == 0) {
        // Nothing to cut; this also avoids reading a stale s[1] from an earlier query.
        printf("0\n");
        return;
    }

    std::sort(q + 1, q + m + 1, cmp);
    // Remove repeated nodes: a duplicate would be pushed twice and create a self-loop.
    m = static_cast<int>(std::unique(q + 1, q + m + 1) - (q + 1));

    // s[1..top] is the current root-to-node chain of the virtual tree.
    // s[0] is never written and stays 0 (dfn[0] is expected to be 0), so it acts
    // as a sentinel below every real node.
    for (int i = 1; i <= m; i++) {
        if (!top) {
            s[++top] = q[i];
            continue;
        }
        const int lca = calc_lca(q[i], s[top]);
        while (dfn[lca] < dfn[s[top]]) {
            if (dfn[lca] >= dfn[s[top - 1]]) {
                // lca lies between s[top - 1] and s[top]: hang s[top] under it
                // and make sure lca itself is on the stack.
                add(lca, s[top], 0);
                if (s[--top] != lca) s[++top] = lca;
                break;
            }
            add(s[top - 1], s[top], 0);
            top--;
        }
        s[++top] = q[i];
    }
    // Link whatever is left on the chain.
    while (top > 1) {
        add(s[top - 1], s[top], 0);
        top--;
    }

    dfs2(s[1]);
    printf("%lld\n", ans[s[1]]);  // one answer per line

    for (int i = 1; i <= m; i++) flag[q[i]] = false;
}

int main() {
    n = read();
    memset(head, -1, sizeof(head));
    for (int i = 1; i < n; i++) {
        const int x = read();
        const int y = read();
        const int z = read();
        add(x, y, z);
        add(y, x, z);
    }
    dp[1] = INF;  // the root has no edge above it
    dfs1(1);
    T = read();
    while (T--) solve();
    return 0;
}