题意看这篇博客:https://blog.csdn.net/dreaming__ldx/article/details/88418543
思路看这篇:https://blog.csdn.net/corsica6/article/details/88115948
有个坑点：不能用 DFS（递归）去构造具体方案，否则 test 14 会 MLE（或许是本蒟蒻写丑了），所以下面的代码改用 BFS 来构造方案。
代码:
// Power Tree: tree-DP solution that also lists every vertex belonging to at
// least one optimal purchase set.
//
// Input : n, the costs a[1..n], then n - 1 tree edges; the tree is rooted at 1.
// Output: the minimum total cost and the number of vertices that appear in
//         some optimal set, then those vertices in increasing order.
#include <algorithm>
#include <cstdio>
#include <queue>
#include <utility>
#include <vector>
using namespace std;

using LL = long long;
using pii = pair<int, int>;

const int maxn = 200010;
const LL INF = 0x3f3f3f3f3f3f3f3fLL;  // same value memset(dp, 0x3f, ...) produced

// Adjacency list (forward star); each undirected edge is stored twice.
int head[maxn], Next[maxn * 2], ver[maxn * 2], tot;
int a[maxn];         // purchase cost of each vertex
int f[maxn];         // parent in the rooted tree (0 for the root)
LL dp[maxn][2];      // DP table, see dfs1
LL sum[maxn];        // sum of dp[child][0] over all children of a vertex
bool v[maxn][2];     // BFS visited flag per (vertex, state)
bool res[maxn];      // vertex is bought in at least one optimal solution
bool is_leaf[maxn];  // non-root vertex without children

void add(int x, int y) {
    ver[++tot] = y;
    Next[tot] = head[x];
    head[x] = tot;
}

// Fills dp[][] bottom-up.
//   dp[x][0]: min cost so that the subtree of x is completely handled.
//             Either every child is in state 0 (x not bought), or exactly one
//             child is in state 1 and x is bought to close that child's chain.
//   dp[x][1]: min cost when one child chain is left open (exactly one child in
//             state 1, x not bought); an ancestor must close it. A leaf has
//             dp[leaf][1] = 0 and dp[leaf][0] = a[leaf].
// Uses an explicit stack instead of recursion so that a path-shaped tree with
// ~2e5 vertices cannot overflow the call stack.
void dfs1(int root, int n) {
    vector<int> order;
    order.reserve(n);
    vector<int> stk(1, root);
    f[root] = 0;
    while (!stk.empty()) {
        int x = stk.back();
        stk.pop_back();
        order.push_back(x);
        for (int i = head[x]; i; i = Next[i]) {
            int y = ver[i];
            if (y == f[x]) continue;
            f[y] = x;
            stk.push_back(y);
        }
    }
    // Reverse preorder handles every child before its parent.
    for (int k = static_cast<int>(order.size()) - 1; k >= 0; --k) {
        int x = order[k];
        int children = 0;
        sum[x] = 0;
        for (int i = head[x]; i; i = Next[i]) {
            int y = ver[i];
            if (y == f[x]) continue;
            sum[x] += dp[y][0];
            ++children;
        }
        if (children == 0) {
            dp[x][0] = a[x];
            dp[x][1] = 0;
            is_leaf[x] = true;
            continue;
        }
        dp[x][0] = dp[x][1] = INF;
        for (int i = head[x]; i; i = Next[i]) {
            int y = ver[i];
            if (y == f[x]) continue;
            // Swap child y from state 0 to state 1.
            LL open_y = sum[x] - dp[y][0] + dp[y][1];
            dp[x][0] = min(dp[x][0], open_y + a[x]);
            dp[x][1] = min(dp[x][1], open_y);
        }
        dp[x][0] = min(dp[x][0], sum[x]);
    }
}

// Marks in res[] every vertex that is bought in at least one optimal solution.
// A state (x, s) stands for dp[x][s]. Starting from (root, 0), we follow every
// transition that attains the optimum. Reconstruction is done with BFS rather
// than a recursive search, because the recursive version ran out of memory on
// test 14.
void bfs() {
    queue<pii> q;
    q.push(make_pair(1, 0));
    while (!q.empty()) {
        int x = q.front().first;
        int flag = q.front().second;
        q.pop();
        if (v[x][flag]) continue;
        v[x][flag] = true;
        if (flag == 0 && is_leaf[x]) {  // a fully handled leaf must buy itself
            res[x] = true;
            continue;
        }
        // Only state 0 may buy x; doing so closes exactly one child's chain.
        const LL extra = (flag == 0) ? a[x] : 0;
        int pos = -1, cnt = 0;
        for (int i = head[x]; i; i = Next[i]) {
            int y = ver[i];
            if (y == f[x]) continue;
            if (dp[x][flag] == sum[x] - dp[y][0] + dp[y][1] + extra) {
                if (v[y][1]) continue;
                if (flag == 0) res[x] = true;
                q.push(make_pair(y, 1));
                pos = y;
                ++cnt;
            }
        }
        // Every other child is fully handled (state 0) in such a solution.
        for (int i = head[x]; i; i = Next[i]) {
            int y = ver[i];
            if (v[y][0]) continue;
            if (y == f[x] || y == pos) continue;
            q.push(make_pair(y, 0));
        }
        // pos may also end up in state 0. This happens when another child ties
        // for the open chain, or (state 0 only) when not buying x is optimal too.
        if (cnt > 1 || (flag == 0 && sum[x] == dp[x][0] && pos != -1)) {
            if (!v[pos][0]) q.push(make_pair(pos, 0));
        }
    }
}

int main() {
    int n;
    if (scanf("%d", &n) != 1) return 0;
    for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
    for (int i = 1; i < n; i++) {
        int x, y;
        scanf("%d%d", &x, &y);
        add(x, y);
        add(y, x);
    }
    dfs1(1, n);
    bfs();

    vector<int> ans;  // filled in increasing vertex order, so no sort needed
    for (int i = 1; i <= n; i++)
        if (res[i]) ans.push_back(i);

    // ans.size() is size_t; cast it so the argument matches %d.
    printf("%lld %d\n", dp[1][0], static_cast<int>(ans.size()));
    for (int id : ans) printf("%d ", id);
    printf("\n");
    return 0;
}
最小生成树解法先留个坑在这。。。
补坑:
思路看这篇博客:https://www.cnblogs.com/river-flows-in-you/p/10596821.html
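说一下我个人的理解：按 DFS 序给叶子编号 1..k，则每个顶点 x 的子树恰好覆盖一段连续的叶子区间 [l_x, r_x]。对 x 的一次操作等于给这段区间整体加上同一个数；用差分的思想看，就是在差分数组的位置 l_x 和 r_x+1 之间转移数值。于是把 x 看成新图中连接顶点 l_x 与 r_x+1、权值为 x 的代价的一条边（新图共有 k+1 个顶点）。只要选出的边把这 k+1 个顶点连通（即包含一棵生成树），叶子上的任意初值就都能被调成 0，因为形成生成树之后，我们对每个点的赋值操作就可以类比在生成树上遍历的过程。所以最小代价就是最小生成树的权值，而需要输出的顶点就是能出现在某棵最小生成树中的边所对应的顶点。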
代码:
// Power Tree: minimum-spanning-tree solution.
//
// Leaves are numbered 1..k in DFS preorder, so the subtree of every vertex x
// covers a contiguous leaf range [l[x], r[x]]. Buying x adds a value to that
// range. In difference-array terms, that moves value between positions l[x]
// and r[x] + 1. So each vertex becomes an edge (l[x], r[x] + 1) of weight a[x]
// on the k + 1 difference positions, and the minimum cost is the weight of a
// minimum spanning tree of that graph. The vertices to print are those whose
// edge lies in at least one MST, detected by Kruskal processed per weight group.
//
// Input : n, the costs a[1..n], then n - 1 tree edges; the tree is rooted at 1.
// Output: the minimum cost and the number of such vertices, then the vertices.
#include <algorithm>
#include <cstdio>
#include <vector>
using namespace std;

using LL = long long;
const int INF = 0x3f3f3f3f;
const int maxn = 200010;

// Adjacency list (forward star); each undirected edge is stored twice.
int head[maxn], Next[maxn * 2], ver[maxn * 2], tot;
int cnt;               // number of leaves numbered so far
int l[maxn], r[maxn];  // leaf-index range covered by each subtree
int a[maxn];           // purchase cost of each vertex
int f[maxn];           // DSU parent over difference positions 1..cnt+1
int par[maxn];         // parent in the rooted tree (0 for the root)
bool chosen[maxn];     // vertex's edge belongs to at least one MST

struct Edge {
    int u, v, w, id;
    bool operator<(const Edge& rhs) const { return w < rhs.w; }
};
Edge b[maxn];  // b[x] is the edge contributed by vertex x

void add(int x, int y) {
    ver[++tot] = y;
    Next[tot] = head[x];
    head[x] = tot;
}

// Numbers the leaves, computes l/r for every vertex and builds b[1..n].
// Iterative, so a path-shaped tree with ~2e5 vertices cannot overflow the
// call stack.
void dfs(int root, int n) {
    vector<int> order;
    order.reserve(n);
    vector<int> stk(1, root);
    par[root] = 0;
    while (!stk.empty()) {
        int x = stk.back();
        stk.pop_back();
        order.push_back(x);
        for (int i = head[x]; i; i = Next[i]) {
            int y = ver[i];
            if (y == par[x]) continue;
            par[y] = x;
            stk.push_back(y);
        }
    }
    // In a DFS preorder every subtree is a contiguous run, so numbering leaves
    // in this order gives each subtree a contiguous leaf range.
    for (int x : order) {
        bool leaf = true;
        for (int i = head[x]; i && leaf; i = Next[i])
            if (ver[i] != par[x]) leaf = false;
        if (leaf) {
            l[x] = r[x] = ++cnt;
        } else {
            l[x] = INF;
            r[x] = -1;
        }
    }
    // Reverse preorder: a vertex's range is final before it is merged upward.
    for (int k = static_cast<int>(order.size()) - 1; k >= 0; --k) {
        int x = order[k];
        int p = par[x];
        if (p) {
            l[p] = min(l[p], l[x]);
            r[p] = max(r[p], r[x]);
        }
        b[x] = Edge{l[x], r[x] + 1, a[x], x};
    }
}

// DSU find with iterative path compression (no deep recursion).
int find_set(int x) {
    int root = x;
    while (f[root] != root) root = f[root];
    while (f[x] != root) {
        int nxt = f[x];
        f[x] = root;
        x = nxt;
    }
    return root;
}

int main() {
    int n;
    if (scanf("%d", &n) != 1) return 0;
    for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
    for (int i = 1; i < n; i++) {
        int x, y;
        scanf("%d%d", &x, &y);
        add(x, y);
        add(y, x);
    }
    dfs(1, n);
    sort(b + 1, b + 1 + n);

    const int points = cnt + 1;  // difference positions 1..k+1
    for (int i = 1; i <= points; i++) f[i] = i;

    LL ans = 0;
    int sum = 0;
    for (int L = 1, R; L <= n; L = R + 1) {
        // [L, R] is a maximal run of equal weight; bound-check before reading.
        R = L;
        while (R < n && b[R + 1].w == b[L].w) ++R;
        // Pass 1: against the components formed by strictly lighter edges, an
        // edge joining two different components lies in some MST.
        for (int i = L; i <= R; i++) {
            if (find_set(b[i].u) != find_set(b[i].v)) {
                chosen[b[i].id] = true;
                ++sum;
            }
        }
        // Pass 2: ordinary Kruskal unions for this weight group.
        for (int i = L; i <= R; i++) {
            int x = find_set(b[i].u), y = find_set(b[i].v);
            if (x != y) {
                ans += b[i].w;
                f[x] = y;
            }
        }
    }

    printf("%lld %d\n", ans, sum);
    for (int i = 1; i <= n; i++)
        if (chosen[i]) printf("%d ", i);
    printf("\n");
    return 0;
}