题意:
给定一棵树, 树的叶子节点是用户, 每个用户会提供一定的收益, 其他节点都是中间节点. 中间节点之间和中间节点到叶子节点上都有耗费. 求解在总耗费小于总收益的情况下所能得到的最多用户数
思路:
1. 树形DP, DP[i][j] 表示在以第 i 个节点为根的树上上保留 j 个用户的最大收益. 这样, 求解 dp[i][j] >=0 时最大的 j 即可. 这个 dp 的定义没想到
2. DP[i][j] = min(DP[i][j], DP[i][j-k]+DP[v][k]-c), c 表示节点 i 到节点 v 的耗费
总结:
1. 初始化时, dp[u][0] = 0, dp[u][else] = -INF. 初始化紧扣 DP 的定义, 符合逻辑
细节:
1. 初始化为 -INF, 求 max 时, 两个 -INF 相加可能会溢出, 因此进行了一步判断 if(dp[u][j-k] != -INFS && dp[v][k] != -INFS), 当然 ,不进行判断也是可以的, 这时要求 INF 不要设置的太大(1e9就不需要判断, 2*1e9 不会溢出), 这个技巧上一篇也讨论过
2. 下面代码中 j 的遍历若是从 N 开始, 则会 TLE. 从 N-M 开始会不会 LTE 没试过. 但是从 children[u] 开始倒是 AC 了. 计算 children 的时候, 我本打算在以 children += dfs(v) 的形式计算, 不过这似乎是不可能的, 只能单列一个方法, 单独 pass 一遍树. 另外, children[叶子] == 1 不合逻辑, 不过这倒是方便了计算
3. 还有一个技巧, 是对 dp[叶子][1] 的初始化, 减少了分支的判断, 少用了一个数组, 缺点是 dfs 函数的初始化乱了一点
for(int i = t+1; i <= N; i ++) {
cin >> dp[i][1];
}
代码:
#include <iostream> #include <vector> using namespace std; class node { public: int id, ct; node(int _id, int _ct):id(_id), ct(_ct){} node() { node(0,0); } }; const int INFS = 1E9; const int MAXN = 3200; int N, M, K; vector<node> graph[MAXN]; int cost[MAXN]; int dp[MAXN][MAXN]; int children[MAXN]; void dfs(int u) { for(int i = 2; i <= N; i ++) dp[u][i] = -INFS; dp[u][0] = 0; if(u > N-M) { return; } dp[u][1] = -INFS; for(int i = 0; i < graph[u].size(); i ++) { int v = graph[u][i].id; int c = graph[u][i].ct; dfs(v); for(int j = children[u]; j > 0; j --) for(int k = 0; k <= j; k ++) //if(dp[u][j-k] != -INFS && dp[v][k] != -INFS) // 防止溢出 dp[u][j] = max(dp[u][j], dp[u][j-k]+dp[v][k]-c); } } void getChildren(int u) { if(u > N-M) { children[u] = 1; return; } for(int i = 0; i < graph[u].size(); i ++) { int v = graph[u][i].id; getChildren(v); children[u] += children[v]; } } int main() { freopen("E:\Copy\ACM\poj\1155\in.txt", "r", stdin); while(cin >> N >> M) { int t = N-M; for(int i = 0; i < t; i ++) { cin >> K; int id, cost; for(int j = 0; j < K; j ++) { cin >> id >> cost; graph[i+1].push_back(node(id, cost)); } } for(int i = t+1; i <= N; i ++) { cin >> dp[i][1]; // 减少了很多分支判断条件 } getChildren(1); dfs(1); for(int i = N; i >= 0; i--) { if(dp[1][i] >= 0) { cout << i << endl; break; } } } return 0; }