简介
dsu on tree跟dsu没有关系,但是dsu on tree借鉴了dsu的启发式合并的思想。
它是用来解决一类树上的询问问题,一般这种问题有以下特征:
(1.)只有对子树的查询;
(2.)没有修改。
如果满足以上特征,那么dsu on tree很可能就可以派上用场了。
算法
我们以CF600E Lomsat gelral为例。
Descrption
一棵以(1)为根的树有 (n) 个结点,每个结点的颜色是(c_i),每个颜色有一个编号,求树中每个子树的最多的颜色编号的和。
数据范围(1le nle 10^5, 1le c_ile 10^5)。
Solution
我们考虑暴力怎么写:遍历每一个节点 (u) ,然后把子树内的所有颜色暴力统计出来更新答案。
然后消除节点 (u) 的贡献,继续递归计算其他点的贡献。
复杂度(O(n^2)),显然是很不优美的。
然后dsu on tree就登场了。它也是一个暴力,但是它结合了轻重链剖分,将复杂度降到了(O(nlogn))。
我们先跑一个(dfs1),记录节点(u)的重儿子(heavy[u])。
接下来:
- 遍历每一个节点。
- 递归所有的轻儿子,递归进入时暴力加贡献,递归结束时暴力消除贡献。
- 递归重儿子,递归进入时暴力加贡献,递归结束时不消除贡献。
- 统计所有轻儿子的答案,并更新该节点的答案。
- 删除所有轻儿子对答案的影响。
复杂度(O(nlogn)),可以通过本题。
Code
// Author: wlzhouzhuan
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define ull unsigned long long
#define rint register int
#define rep(i, l, r) for (rint i = l; i <= r; i++)
#define per(i, l, r) for (rint i = l; i >= r; i--)
#define mset(s, _) memset(s, _, sizeof(s))
#define pb push_back
#define pii pair <int, int>
#define mp(a, b) make_pair(a, b)
inline int read() {
int x = 0, neg = 1; char op = getchar();
while (!isdigit(op)) { if (op == '-') neg = -1; op = getchar(); }
while (isdigit(op)) { x = 10 * x + op - '0'; op = getchar(); }
return neg * x;
}
inline void print(int x) {
if (x < 0) { putchar('-'); x = -x; }
if (x >= 10) print(x / 10);
putchar(x % 10 + '0');
}
const int N = 100005;
int n;
vector <int> adj[N];
void add(int u, int v) { adj[u].pb(v); }
int sz[N], fa[N], heavy[N]; // heavy[u] 表示u的重儿子
void dfs1(int u, int f) {
sz[u] = 1, fa[u] = f;
for (auto v: adj[u]) {
if (v == f) continue;
dfs1(v, u);
sz[u] += sz[v];
if (sz[v] > sz[heavy[u]]) {
heavy[u] = v; // 更新重儿子
}
}
}
long long ans[N], sum;
int cnt[N], col[N], Son, Max;
void change(int u, int val) {
cnt[col[u]] += val;
if (cnt[col[u]] > Max) Max = cnt[col[u]], sum = col[u];
else if (cnt[col[u]] == Max) sum += col[u];
for (auto v: adj[u]) {
// 由于重儿子的信息没有被删去,所以已经统计过了,不能再计算
if (v == fa[u] || v == Son) continue;
change(v, val);
}
}
void dfs2(int u, int keep) {
for (auto v: adj[u]) {
if (v == fa[u] || v == heavy[u]) continue;
dfs2(v, 0); // 遍历u的轻儿子
}
if (heavy[u]) dfs2(heavy[u], 1), Son = heavy[u];
change(u, 1), ans[u] = sum, Son = 0;
if (!keep) change(u, -1), sum = Max = 0;
}
int main() {
n = read();
for (rint i = 1; i <= n; i++) {
col[i] = read();
}
for (rint i = 1; i < n; i++) {
int x = read(), y = read();
add(x, y), add(y, x);
}
int root = 1; // 题目默认1为根
dfs1(root, 0);
dfs2(root, 0);
for (rint i = 1; i <= n; i++) {
printf("%lld ", ans[i]);
}
return 0;
}