题目链接:传送门
题目大意:
给出节点数为n的一棵带权树,和每个点的最大染色数k。一条边的权重w能产生价值w的条件是,这条边的两端的点至少有一个颜色相同。颜色种类数无限,但每种只能使用两次,问能产生的最大总价值。
思路:
(这两天刷dp专题ing,一上来就会朝dp方面想。看到题目中给出的还是个树,直接盲猜树形dp开始搓状态。结果还tm都搓出来了,以后再看到树会不会巴甫洛夫效应了呀!)
这题考虑树形dp。树形dp的话一般都是考虑一棵以u为根的树的状态,能不能很好地从以u的儿子v为根的子树的状态转移过来。
这题中想从v转移到u,只用考虑u和v有没有染相同颜色即可。
所以用f[u][0/1]表示以u为根的子树,u和u的父亲染不染相同颜色的条件下(0表示不染同色,1反之),产生的最大价值。
在f[u][0]中u最多能和k个儿子染相同的颜色,这些点产生的贡献是$sum_{u是v的父亲} f[v][1] + edge_{u, v}$,此外还要加上剩余的所有没考虑过的v的f[v][0]。
同理在f[u][1]中u最多能和k-1个儿子染相同的颜色。
考虑如何选择这k/k-1个儿子v:
他们如果连父亲,产生的贡献为f[v][1] + $edge_{u, v}$,如果不连父亲,产生的贡献是f[v][0]。连上父亲对总答案的贡献的增量为f[v][1] + $edge_{u,v}$ - f[v][0],根据这个增量排序从大到小取k/k-1个就可以了(注意一下可能有增量小于0的情况)
代码:O(nlogn)
#include <bits/stdc++.h> #define fast ios::sync_with_stdio(false), cin.tie(0), cout.tie(0) #define N 500005 #define M 500005 #define INF 0x3f3f3f3f #define mk(x) (1<<x) // be conscious if mask x exceeds int #define sz(x) ((int)x.size()) #define upperdiv(a,b) (a/b + (a%b>0)) #define mp(a,b) make_pair(a, b) #define endl ' ' #define lowbit(x) (x&-x) using namespace std; typedef long long ll; typedef double db; /** fast read **/ template <typename T> inline void read(T &x) { x = 0; T fg = 1; char ch = getchar(); while (!isdigit(ch)) { if (ch == '-') fg = -1; ch = getchar(); } while (isdigit(ch)) x = x*10+ch-'0', ch = getchar(); x = fg * x; } template <typename T, typename... Args> inline void read(T &x, Args &... args) { read(x), read(args...); } template <typename T> inline void write(T x) { int len = 0; char c[21]; if (x < 0) putchar('-'), x = -x; do{++len; c[len] = x%10 + '0';} while (x /= 10); for (int i = len; i >= 1; i--) putchar(c[i]); } template <typename T, typename... Args> inline void write(T x, Args ... args) { write(x), write(args...); } int n, k; struct Node{ int u; ll val; bool operator < (const Node& x) const { return val > x.val; } }; int tot = 0; int head[N], nxt[M<<1], ver[M<<1]; ll wei[M<<1]; void addEdge(int u, int v, ll w) { nxt[++tot] = head[u], ver[tot] = v, wei[tot] = w, head[u] = tot; } bool added[N][2]; ll f[N][2]; void dfs(int u, int p) { vector <Node> nodes; for (int i = head[u]; i != -1; i = nxt[i]) { int v = ver[i]; ll w = wei[i]; if (v == p) continue; dfs(v, u); nodes.push_back(Node{v, f[v][1]+w - f[v][0]}); } sort(nodes.begin(), nodes.end()); for (int i = 0; i < sz(nodes) && i < k; i++) { Node tmp = nodes[i]; if (tmp.val <= 0) break; if (i < k-1) { f[u][1] += tmp.val + f[tmp.u][0]; added[tmp.u][1] = true; } f[u][0] += tmp.val + f[tmp.u][0]; added[tmp.u][0] = true; } for (int i = head[u]; i != -1; i = nxt[i]) { int v = ver[i]; if (v == p) continue; if (!added[v][1]) f[u][1] += f[v][0]; if (!added[v][0]) f[u][0] += f[v][0]; } } int main() { int q; read(q); while (q--) { read(n, k); tot = 0; for (int i = 1; i <= n; i++) { head[i] = -1; added[i][0] = added[i][1] = false; f[i][0] = f[i][1] = 0; } for (int i = 1; i <= n-1; i++) { int u, v; ll w; read(u, v, w); addEdge(u, v, w); addEdge(v, u, w); } dfs(1, -1); ll ans = f[1][0]; cout << ans << endl; } return 0; }