( ext{dsu on tree}) 略解
简介
首先 ( ext{dsu on tree}) 和并查集并没有关系,其用来处理一类树上问题,一般有两个特征:
- 不带修改
- 询问与子树有关
( ext{dsu on tree}) 可以十分方便的在 (O(nlogn)) 的时间复杂度内解决。
大致思路
( ext{dsu on tree}) 利用了重链剖分中重儿子的思想来进行暴力。
例如一道例题 CF600E:求每个子树内出现次数最多的颜色之和
(O(n^2)) 暴力十分显然,但是可以发现一个性质:父子之间的信息共享,而兄弟之间的信息不共享,也就是计算完最后一个子树的信息后,可以不用清空,其信息可以保留下来继续给父亲使用。
所以我们想到使最后一个遍历的子树尽可能大,也就是 重儿子。
算法流程
设当前求到 (u) 的答案 (ans_u),算法大致分为 (5) 步:
- 计算 轻儿子 (v) 的 (ans_v)
- 计算 (u) 重儿子 (son_u) 的 (ans_{son_u}),并将 (son_u) 的信息保留继续使用
- 再暴力计算每个轻儿子的信息
- 更新 (ans_u)
- 如果 (u) 不为重儿子,则暴力删去 (u) 的信息
”暴力计算 (v)“ 指将以 (v) 为根的子树遍历一遍计算信息(也可能因题目而异吧)
复杂度
首先有一个重要的性质:一个节点到根路径上的轻边数不超过 (logn),证明:
由轻重儿子的性质可知:对于 (u) 的任意轻儿子 (v) 有 (siz_v leq frac{siz_u}{2})
因此每经过一条轻边 (siz/2),那么任意点开始往叶子节点走经过轻边数量最多不超过 (logn) 条
得证
再考虑每个点 (v) 会被计算多少次,按其到根的路径上的轻/重边分为两类讨论:
- 对于每条轻边,都需要单独计算一次 (v) 的信息,由以上性质知不超过 (logn) 次
- 对于 (v) 到根路径上的每条重边,是不需要再计算 (v) 的
所以对于节点 (v),一共会被计算 (logn + 1) 次((1) 为计算 (ans_v))
综上,若计算一个点的信息为 (O(1)),则该算法时间复杂度为 (O(nlogn))。
例题
CF600E
第一次打 (Code) 有点丑......
Code
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define N 100000
#define fo(i, x, y) for(int i = x; i <= y; i ++)
#define Fo(i, u) for(int i = head[u]; i; i = edge[i].next)
#define ll long long
void read(int &x) {
char ch = getchar(); x = 0;
while (ch < '0' || ch > '9') ch = getchar();
while (ch >= '0' && ch <= '9') x = (x << 1) + (x << 3) + ch - 48, ch = getchar();
}
struct EDGE { int next, to; } edge[N << 1];
int head[N + 1], col[N + 1], h[N + 1], sz[N + 1], son[N + 1], las[N + 1];
ll ans[N + 1];
int n;
int cnt_edge = 1;
void Add(int u, int v) { edge[ ++ cnt_edge ] = (EDGE) { head[u], v }, head[u] = cnt_edge; }
void Link(int u, int v) { Add(u, v), Add(v, u); }
void Dfs1(int u, int la) {
sz[u] = 1, son[u] = 0;
Fo(i, u) if (i != la) {
Dfs1(edge[i].to, i ^ 1);
if (sz[edge[i].to] > sz[son[u]])
son[u] = edge[i].to;
sz[u] += sz[edge[i].to];
}
}
int max_h = 0;
ll sum = 0;
void Add1(int c, int d) {
h[c] += d;
if (h[c] > max_h) max_h = h[c], sum = c;
else if (h[c] == max_h) sum += c;
}
void Dfs3(int u, int la, int d) {
Add1(col[u], d);
Fo(i, u) if (i != la)
Dfs3(edge[i].to, i ^ 1, d);
}
void Dfs2(int u, int fa, int opt) {
int v = 0;
Fo(i, u) if ((v = edge[i].to) != fa && v != son[u])
Dfs2(v, u, 1);
if (son[u]) Dfs2(son[u], u, 0);
Fo(i, u) if ((v = edge[i].to) != fa && v != son[u])
Dfs3(v, i ^ 1, 1);
Add1(col[u], 1);
ans[u] = sum;
if (opt) {
Fo(i, u) if ((v = edge[i].to) != fa)
Dfs3(v, i ^ 1, -1);
Add1(col[u], -1);
sum = max_h = 0;
}
}
int main() {
read(n);
fo(i, 1, n) read(col[i]);
for (int i = 1, x, y; i < n; i ++)
read(x), read(y), Link(x, y);
Dfs1(1, 0);
Dfs2(1, 0, 0);
fo(i, 1, n) printf("%lld ", ans[i]);
return 0;
}
CF741D
Solution
由回文串的性质可知:区间内之多只有一个字符出现奇数次。
借此可以将统计出现次数转化为异或,可以用大小为 (2^{22}) 的状态表示从根开始的路径上每个字符出现次数的奇偶性。
设 (dis_{u}) 为从根到 (x) 的路径上字符的状态,那么任意路径 ((u, v)) 的字符状态就可以表示为 (dis_{(u, v)} = dis_u oplus dis_v oplus dis_{lca} oplus dis_{lca}),由异或的性质可知即为 (dis_u oplus dis_v),而距离就是 (dep_u + dep_v - 2dep_{lca})。
所以只需要用大小 (2^{22}) 的桶存下每个状态的最深深度,(O(22)) 可以求出一个点对答案的贡献,每个点 (u) 的 (ans_u) 为其子树 (ans) 与经过 (u) 最长符合条件路径的最大值。
剩下的就基本是 ( ext{dsu on tree}) 的模板了。
Code
#include <cstdio>
using namespace std;
#define N 500000
#define M 22
#define inf 10000
#define fo(i, x, y) for(int i = x, end_##i = y; i <= end_##i; i ++)
#define fd(i, x, y) for(int i = x, end_##i = y; i >= end_##i; i --)
#define Fo(i, u) for(int i = head[u]; i; i = edge[i].next)
void read(int &x) {
char ch = getchar(); x = 0;
while (ch < '0' || ch > '9') ch = getchar();
while (ch >= '0' && ch <= '9') x = (x << 1) + (x << 3) + ch - 48, ch = getchar();
}
struct EDGE { int next, to; } edge[N << 1];
int head[N + 1], col[N + 1], f[1 << M], d[N + 1], sz[N + 1], son[N + 1], fa[N + 1], c[N + 1], ans[N + 1];
int n;
int cnt_edge = 0;
void Add(int u, int v) { edge[ ++ cnt_edge ] = (EDGE) { head[u], v }, head[u] = cnt_edge; }
int max(int x, int y) { return x > y ? x : y; }
void Init() {
d[1] = 1;
fo(i, 2, n) c[i] = c[fa[i]] ^ (1 << col[i]), d[i] = d[fa[i]] + 1;
fd(i, n, 1) {
if (++ sz[i] > sz[son[fa[i]]])
son[fa[i]] = i;
sz[fa[i]] += sz[i];
}
fo(i, 2, n) if (i != son[fa[i]])
Add(fa[i], i);
fo(i, 0, (1 << M) - 1) f[i] = -inf;
}
int Get_d(int x) {
int dep = f[x];
fo(i, 0, M - 1)
dep = max(dep, f[x ^ (1 << i)]);
return dep;
}
int Dfs1(int u) {
int dep = Get_d(c[u]) + d[u];
Fo(i, u) dep = max(dep, Dfs1(edge[i].to));
if (son[u]) dep = max(dep, Dfs1(son[u]));
return dep;
}
void Updata(int x, int dep) { f[x] = max(f[x], dep); }
void Dfs2(int u) {
Updata(c[u], d[u]);
Fo(i, u) Dfs2(edge[i].to);
if (son[u]) Dfs2(son[u]);
}
void Back(int x) { f[x] = -inf; }
void Dfs3(int u) {
Back(c[u]);
Fo(i, u) Dfs3(edge[i].to);
if (son[u]) Dfs3(son[u]);
}
void Solve(int u, int opt) {
ans[u] = 0;
Fo(i, u) Solve(edge[i].to, 1), ans[u] = max(ans[u], ans[edge[i].to]);
if (son[u]) Solve(son[u], 0), ans[u] = max(ans[u], ans[son[u]]);
ans[u] = max(ans[u], Get_d(c[u]) - d[u]);
Updata(c[u], d[u]);
Fo(i, u)
ans[u] = max(ans[u], Dfs1(edge[i].to) - (d[u] << 1)), Dfs2(edge[i].to);
if (opt) Dfs3(u);
}
int main() {
read(n);
fo(i, 2, n)
read(fa[i]), col[i] = getchar() - 'a';
Init();
Solve(1, 0);
fo(i, 1, n) printf("%d ", ans[i]);
return 0;
}