Solution
我们发现要解决一个树上的连通块问题,解决这种问题的时候我们不妨先随便选一个根,如果要选某两个点则他们到n的路径上的点都会被选就变成了一个树形背包问题。
注意这里是多重背包,所以我们可以用单调队列优化,时间复杂度$O(N^2M)$。
考虑暴力选根的时候会把很多重复的情况算进去,所以我们可以用点分治,只计算根的孩子之间的贡献,递归子树时其余兄弟节点就不用管了。
因为每次选的是重心,所以子树大小必然减一半,时间复杂度$O(NMlog{N})$。
点分治+树形背包,这是一种常见的处理树上连通块的方法。
Code
因为加了单调队列优化,所以要注意树形背包时倒着做(即从叶节点开始)。
#include <iostream> #include <cstdio> #include <cstring> #include <algorithm> using namespace std; const int N = 5010, M = 40010, inf = 0x3f3f3f3f; struct node{ int pre, to; }edge[N << 1]; int head[N], tot; int T; int n, m; int sz[N], mx[N], rt; int w[N], v[N], d[N]; int dfn[N], bl[N], dep; int dp[N][M]; int ans; bool vis[N]; void get_root(int x, int tot_size, int fa) { sz[x] = 1; mx[x] = 0; for (int i = head[x]; i; i = edge[i].pre) { int y = edge[i].to; if (y == fa || vis[y]) continue; get_root(y, tot_size, x); sz[x] += sz[y]; if (mx[x] < sz[y]) { mx[x] = sz[y]; } } mx[x] = max(mx[x], tot_size - sz[x]); if (mx[x] < mx[rt]) { rt = x; } } void dfs(int x, int fa) { sz[x] = 1; dfn[++dep] = x; bl[x] = dep; for (int i = head[x]; i; i = edge[i].pre) { int y = edge[i].to; if (y == fa || vis[y]) continue; dfs(y, x); sz[x] += sz[y]; } } void cmax(int &x, int y) { x = max(x, y); } int q[N]; void solve(int x) { vis[x] = 1; dep = 0; dfs(x, 0); for (int i = 0; i <= dep + 1; i++) { for (int j = 0; j <= m; j++) { dp[i][j] = 0; } } for (int i = dep; i >= 1; i--) { for (int j = 0; j <= m; j++) { cmax(dp[i][j], dp[i + sz[dfn[i]]][j]); } int a = v[dfn[i]]; int b = w[dfn[i]]; int c = d[dfn[i]]; for (int j = 0; j < a; j++) { int heead = 1, tail = 0; for (int k = 0; j + k * a <= m; k++) { while (heead <= tail && q[heead] < k - c) heead++; if (heead <= tail) cmax(dp[i][j + k * a], dp[i + 1][j + q[heead] * a] - q[heead] * b + k * b); while (heead <= tail && dp[i + 1][j + a * q[tail]] - q[tail] * b <= dp[i + 1][j + a * k] - k * b) tail--; q[++tail] = k; } } } for (int i = 1; i <= m; i++) { ans = max(ans, dp[1][i]); } for (int i = head[x]; i; i = edge[i].pre) { int y = edge[i].to; if (vis[y]) continue; rt = 0; get_root(y, sz[y], x); solve(rt); } } int read() { int ret = 0, f = 1; char ch = getchar(); while (!isdigit(ch)) { if (ch == '-') f = -1; ch = getchar(); } while (isdigit(ch)) { ret = (ret << 1) + (ret << 3) + ch - '0'; ch = getchar(); } return ret * f; } void write(int x) { if (x > 9) write(x / 10); putchar(x % 10 + '0'); } void print(int x) { if (x < 0) { x = -x; putchar('-'); } write(x); putchar(' '); } void init() { ans = 0; memset(vis, 0, sizeof(vis)); memset(head, 0, sizeof(head)); tot = 0; } void add(int u, int vv) { edge[++tot] = node{head[u], vv}; head[u] = tot; } int main() { mx[0] = inf; T = read(); while (T--) { init(); n = read(); m = read(); for (int i = 1; i <= n; i++) { w[i] = read(); } for (int i = 1; i <= n; i++) { v[i] = read(); } for (int i = 1; i <= n; i++) { d[i] = read(); } for (int i = 1, u, vv; i < n; i++) { u = read(); vv = read(); add(u, vv); add(vv, u); } rt = 0; get_root(1, n, 0); solve(rt); print(ans); } return 0; }