首先,观察题意,可以发现在最长链下再接一个点,结果一定更优。
也就是说,可以免费选一条最长链,之后正常选。
我们枚举选的最长链,然后算出剩下部分的最优解。
有4部分:
1、链上每个点都选一个。
2、链上剩下的部分。
3、链的左面。
4、链的右面。
1可以直接计算。
那么,我们需要先进行树形背包,然后再通过某方式将其余3个合并。
我们知道,在此问题中,合并2个背包是(O(k))的;
但3个及以上则是(O(k^2))的,无法承受。
所以,我们只能在计算中就把其中两个合并,这样就只需合并2个了。
可以发现,3和4是正常的树形背包,而2是一个贪心的问题。
但是,我们没有时间给链上的点排序,再贪心选择。
所以,只能将2转为正常的树形背包问题。
可以这样想:选x则要选fx,选x的第二个则要选x的第一个。
那么,我们可以把大于1的拆点,拆成1和a-1。连上父子关系。
不难发现,这样我们就只需要3,4两个合并了。所以,只要算出3,4部分,就能(O(k))出解了。
先考虑如何进行树形背包:
树形背包有2种实现:dfs合并的和dfs序上dp的。
由于本题每个节点上有多个,所以之前的(O(nk))的分析不适用,复杂度是(O(nk^2)),显然超时。
而且,复杂度的瓶颈合并背包至少是(O(k^2))的,这个是max卷积,不能优化。
所以,这种方法不行。
但是,由于我们不需要知道每个节点的子节点的选择信息(如每个点选择了多少子树节点),
所以,可以考虑dfs序上dp的算法。见博客
这个算法的状态数是(O(nk))的,且相当于逐渐添加,没有合并背包。
添加的过程是一个多重背包(由于每个节点上有多个),可以用单调队列优化。这部分可以做到(O(nk))。
而且,我们发现:对于问题3,4(即树链的左右),在dfs序上是一段连续区间。这意味着我们可以直接得出3,4的dp值。
对树做先序遍历,可以得到链的右面(后缀)。对树做后遍历,可以得到链的左面(前缀)。
总结下:
首先,拆点。
然后,对树进行先后序遍历,并用多重背包的单调队列优化算出dp值,(O(nk))。
最后,枚举一个叶子,在(O(k))时间算出结果,总共(O(nk))。
此外,对于子节点,在3,4两部分都会被算到,要注意排除。
注意卡常。
代码:
#include <stdio.h>
#define inf 999999999
#define setdp(i, j, x) dp[i * (k + 1) + j] = x
#define getdp(i, j) dp[(i) * (k + 1) + j]
#define getod(i, j) ld[(i) * (k + 1) + j]
int fr[40010],ne[40010],v[40010],bs = 0,sl[40010],sz[40010],n,k;
void addb(int a, int b) {
v[bs] = b;
ne[bs] = fr[a];
fr[a] = bs++;
}
int xl[40010],si[40010],jl[40010],tm = 0,x1[40010],x2[40010];
void dfs1(int u) {
x1[u] = tm;
xl[tm++] = u;
si[u] = 1;
for (int i = fr[u]; i != -1; i = ne[i]) {
jl[v[i]] = jl[u] + sz[v[i]];
dfs1(v[i]);
si[u] += si[v[i]];
}
}
void dfs2(int u) {
for (int i = fr[u]; i != -1; i = ne[i]) dfs2(v[i]);
xl[++tm] = u;
x2[u] = tm;
}
int dl[500010],dz[500010],he = 0,ta = 0,dp[60000010],ld[60000010];
void insert(int i, int x) {
dz[i] = x;
while (he < ta && dz[dl[ta - 1]] <= x) ta -= 1;
dl[ta++] = i;
}
void del(int i) {
if (he < ta && dl[he] == i) he += 1;
}
int getma() {
if (he < ta) return dz[dl[he]];
else return - inf;
}
bool ez[20010],kz[20010];
int main() {
int T;
scanf("%d", &T);
while (T--) {
scanf("%d%d", &n, &k);
bs = 0;
for (int i = 1; i <= n + n; i++) fr[i] = -1;
for (int i = 1; i <= n; i++) ez[i] = kz[i] = false;
for (int i = 1; i <= n; i++) {
int a;
scanf("%d%d%d", &a, &sl[i], &sz[i]);
if (i > 1) addb(a, i);
ez[a] = true;
}
for (int i = 1; i <= n; i++) {
if (sl[i] > 1) {
sl[i + n] = sl[i] - 1;
sz[i + n] = sz[i];
addb(i, i + n);
sl[i] = 1;
kz[i] = true;
}
}
jl[1] = sz[1];
tm = 0;
dfs1(1);
for (int i = tm - 1; i >= 0; i--) {
he = ta = 0;
for (int j = 0; j <= k; j++) {
int u = xl[i],
ma = getdp(i + si[u], j);
del(j - sl[u] - 1);
if (j > 0) insert(j - 1, getdp(i + 1, j - 1) - sz[u] * (j - 1));
int t = getma() + sz[u] * j;
if (t > ma) ma = t;
setdp(i, j, ma);
}
}
for (int i = 0; i <= tm * (k + 1) + k; i++) {
ld[i] = dp[i];
dp[i] = 0;
}
tm = 0;
dfs2(1);
for (int i = 1; i <= tm; i++) {
he = ta = 0;
for (int j = 0; j <= k; j++) {
int u = xl[i],
ma = getdp(i - si[u], j);
del(j - sl[u] - 1);
if (j > 0) insert(j - 1, getdp(i - 1, j - 1) - sz[u] * (j - 1));
int t = getma() + sz[u] * j;
if (t > ma) ma = t;
setdp(i, j, ma);
}
}
int jg = -inf;
for (int i = 1; i <= n; i++) {
if (ez[i]) continue;
int ma = -inf;
for (int j = 0; j <= k; j++) {
int t = getod(x1[i] + 1, j);
if (!kz[i]) t += getdp(x2[i] - 1, k - j);
else t += getdp(x2[i + n] - 1, k - j);
if (t > ma) ma = t;
}
ma += jl[i];
if (ma > jg) jg = ma;
}
printf("%d
", jg);
for (int i = 0; i <= tm * (k + 1) + k; i++) ld[i] = dp[i] = 0;
}
return 0;
}