这篇题解介绍一下我自己想到的一个换根dp的科技。在这篇题解中也有提及。不知道以前有没有被提出过?
具体思路同UM的题解这里就不再赘述。
主要是换根的部分,我维护一个“重儿子”,即UM题解中的(len_u)是从哪个(v)来的。
这时在换根时若要将(x)换成(y),我们进行分类讨论:
1、(y)不是(x)的重儿子
这种情况在换完根后(x)的重儿子以及最长链不会改变,直接重新计算(y)即可。
2、(y)是(x)的重儿子
这时(x)的重儿子一定会变,我们重新遍历所有与(x)相连的节点(不包括(y)),重新计算(x)的重儿子与最长链。之后再计算(y)。
这样做的复杂度看起来是(O(n^2))其实不然,因为每个点至多会有一个重儿子,所以每个点在换根的时候至多只会重新遍历一遍与其相连的边,而每条边只与两个点相连所以每条边至多会被遍历两遍,一棵树又只有(n-1)条边,所以总的时间复杂度为(O(n))。
放一下我这么做的代码:
mx是最长链,son是重儿子,dp与UM的g一样,have记录子树中有多少人。
#include <iostream>
#include <cstdio>
using namespace std;
typedef long long ll;
const int N = 500010;
namespace IO{
template <typename T> void read(T &x) {
T f = 1;
char ch = getchar();
for (; !isdigit(ch); ch = getchar()) if (ch == '-') f = -1;
for (x = 0; isdigit(ch); ch = getchar()) x = x * 10 + ch - '0';
x *= f;
}
template <typename T> void write(T x) {
if (x > 9) write(x / 10);
putchar(x % 10 + '0');
}
template <typename T> void print(T x) {
if (x < 0) putchar('-'), x = -x;
write(x);
putchar('
');
}
} using namespace IO;
struct node{
int pre, to, val;
}edge[N << 1];
int head[N], tot;
int n, m;
int col[N], have[N], son[N];
ll mx[N], dp[N], ans[N];
void add(int u, int v, int l) {
edge[++tot] = node{head[u], v, l};
head[u] = tot;
}
void dfs1(int x, int fa) {
have[x] = col[x];
for (int i = head[x]; i; i = edge[i].pre) {
int y = edge[i].to;
if (y == fa) continue;
dfs1(y, x);
have[x] += have[y];
dp[x] += dp[y] + edge[i].val * 2 * (have[y] > 0);
if ((edge[i].val + mx[y]) * (have[y] > 0) > mx[x]) {
son[x] = y;
mx[x] = (edge[i].val + mx[y]) * (have[y] > 0);
}
}
}
void change_root(int x, int y, int val) {
have[x] -= have[y];
dp[x] -= dp[y] + val * 2 * (have[y] > 0);
have[y] += have[x];
dp[y] += dp[x] + val * 2 * (have[x] > 0);
if (y == son[x]) {
son[x] = mx[x] = 0;
for (int i = head[x]; i; i = edge[i].pre) {
if (edge[i].to == y) continue;
if ((edge[i].val + mx[edge[i].to]) * (have[edge[i].to] > 0) > mx[x]) {
son[x] = edge[i].to;
mx[x] = (edge[i].val + mx[edge[i].to]) * (have[edge[i].to] > 0);
}
}
}
if ((val + mx[x]) * (have[x] > 0) > mx[y]) {
son[y] = x;
mx[y] = (val + mx[x]) * (have[x] > 0);
}
}
void dfs2(int x, int fa) {
ans[x] = dp[x] - mx[x];
for (int i = head[x]; i; i = edge[i].pre) {
int y = edge[i].to;
if (y == fa) continue;
change_root(x, y, edge[i].val);
dfs2(y, x);
change_root(y, x, edge[i].val);
}
}
int main() {
read(n); read(m);
for (int i = 1, u, v, l; i < n; i++) {
read(u); read(v); read(l);
add(u, v, l); add(v, u, l);
}
for (int i = 1, u; i <= m; i++) {
read(u);
col[u] = 1;
}
dfs1(1, 0);
dfs2(1, 0);
for (int i = 1; i <= n; i++)
print(ans[i]);
return 0;
}