题目描述
给你一个包含 (n(1 le n le 10^6)) 个节点的树,节点编号从 (1) 到 (n),根节点的编号为 (1)。每一个节点都有一个颜色,我们用 (c_i) 来表示节点 (i) 的颜色。
接下来有 (m(1 le m le 10^6)) 次询问,每一次询问都会给你两个整数 (u) 和 (c)((1 le u,c le n)),对于每一次询问,你需要回答:以节点 (u) 为根节点的子树中颜色为 (c) 的节点数量。
题解
大多数人都知道 DSU(并查集,Disjoint Set Union)但是什么是 “dsu on tree”(树上启发式合并,直译为“书上的并查集”)?
什么是树上启发式合并(dsu on tree)?
使用 dsu on tree 我们可以回答如下的问题:
在 (O(n log n)) 时间复杂度内计算所有的节点
v
的子树中存在多少个点满足某一性质。
所以对于这道问题我们就是求解:
给你一棵树,每一个节点都有一个颜色。问题是询问 以节点 v
为根的子树中存在多少个点的颜色为 c
?
暴力解法 (O(n^2))
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1000010;
vector<int> g[maxn];
int n, m, sz[maxn], c[maxn];
struct Query {
int c, ans;
} query[maxn];
vector<int> qid[maxn];
int cnt[maxn];
void add(int u, int p, int x) {
cnt[ c[u] ] += x;
for (vector<int>::iterator it = g[u].begin(); it != g[u].end(); it ++) {
int v = *it;
if (v != p) add(v, u, x);
}
}
void dfs(int u, int p) {
add(u, p, 1);
for (vector<int>::iterator it = qid[u].begin(); it != qid[u].end(); it ++) {
int id = *it;
query[id].ans = cnt[ query[id].c ];
}
add(u, p, -1);
for (vector<int>::iterator it = g[u].begin(); it != g[u].end(); it ++) {
int v = *it;
if (v != p) dfs(v, u);
}
}
int main() {
cin >> n;
for (int i = 1; i <= n; i ++) cin >> c[i];
for (int i = 1; i < n; i ++) {
int a, b;
cin >> a >> b;
g[a].push_back(b);
g[b].push_back(a);
}
cin >> m;
for (int i = 0; i < m; i ++) {
int u, c;
cin >> u >> query[i].c;
qid[u].push_back(i);
}
dfs(1, -1);
for (int i = 0; i < m; i ++)
cout << query[i].ans << endl;
return 0;
}
1. 基于dsu on tree的解法1 (O(n log^2 n))
这个解法采用了 dsu on tree 的思想,将每科子树对应的颜色和数量都存在一个 map 中。父节点复用其重儿子的 map。时间复杂度为遍历的节点数 (n log n) 乘以 map 中获得每个元素的时间 (log n) 等于 (O(n log^2 n))。
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1000010;
vector<int> g[maxn];
int n, m, sz[maxn], c[maxn];
struct Query {
int c, ans;
} query[maxn];
vector<int> qid[maxn];
int cnt[maxn];
map<int, int> mp[maxn];
int mpcnt, mpid[maxn];
void getsz(int u, int p) { // 计算sz[u] -- 以u为根节点的子树大小(包含节点个数)
sz[u] ++;
for (vector<int>::iterator it = g[u].begin(); it != g[u].end(); it ++) {
int v = *it;
if (v != p) getsz(v, u);
}
}
void dfs(int u, int p) {
int mx = -1, bigSon = -1; // mx表示重儿子的sz,bigSon表示重儿子的编号
for (vector<int>::iterator it = g[u].begin(); it != g[u].end(); it ++) {
int v = *it;
if (v != p) {
dfs(v, u);
if (sz[v] > mx) {
mx = sz[v];
bigSon = v;
}
}
}
if (bigSon == -1) // 叶子节点
mpid[u] = ++ mpcnt;
else
mpid[u] = mpid[bigSon];
mp[ mpid[u] ][ c[u] ] ++;
for (vector<int>::iterator it = g[u].begin(); it != g[u].end(); it ++) {
int v = *it;
if (v != p && v != bigSon) {
for (map<int, int>::iterator it2 = mp[ mpid[v] ].begin(); it2 != mp[ mpid[v] ].end(); it2 ++) {
pair<int, int> x = *it2;
mp[ mpid[u] ][ x.first ] += x.second;
}
mp[ mpid[v] ].clear();
}
}
for (vector<int>::iterator it = qid[u].begin(); it != qid[u].end(); it ++) {
int id = *it;
query[id].ans = mp[ mpid[u] ][ query[id].c ];
}
}
int main() {
cin >> n;
for (int i = 1; i <= n; i ++) cin >> c[i];
for (int i = 1; i < n; i ++) {
int a, b;
cin >> a >> b;
g[a].push_back(b);
g[b].push_back(a);
}
cin >> m;
for (int i = 0; i < m; i ++) {
int u, c;
cin >> u >> query[i].c;
qid[u].push_back(i);
}
dfs(1, -1);
for (int i = 0; i < m; i ++)
cout << query[i].ans << endl;
return 0;
}
2. 基于dsu on tree的解法2 (O(n log n))
方法2使用vector代替map,公用一个cnt数组,时间复杂度降到 (O(n log n))。
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1000010;
vector<int> g[maxn];
int n, m, sz[maxn], c[maxn];
struct Query {
int c, ans;
} query[maxn];
vector<int> qid[maxn];
int cnt[maxn];
vector<int> vec[maxn];
int veccnt, vecid[maxn];
void getsz(int u, int p) { // 计算sz[u] -- 以u为根节点的子树大小(包含节点个数)
sz[u] ++;
for (vector<int>::iterator it = g[u].begin(); it != g[u].end(); it ++) {
int v = *it;
if (v != p) getsz(v, u);
}
}
void dfs(int u, int p, bool keep) {
int mx = -1, bigSon = -1; // mx表示重儿子的sz,bigSon表示重儿子的编号
for (vector<int>::iterator it = g[u].begin(); it != g[u].end(); it ++) {
int v = *it;
if (v != p && sz[v] > mx) {
mx = sz[v];
bigSon = v;
}
}
for (vector<int>::iterator it = g[u].begin(); it != g[u].end(); it ++) {
int v = *it;
if (v != p && v != bigSon) dfs(v, u, false);
}
if (bigSon == -1) // 叶子节点
vecid[u] = ++ veccnt;
else {
dfs(bigSon, u, true);
vecid[u] = vecid[bigSon];
}
vec[ vecid[u] ].push_back(u);
cnt[ c[u] ] ++;
for (vector<int>::iterator it = g[u].begin(); it != g[u].end(); it ++) {
int v = *it;
if (v != p && v != bigSon) {
for (vector<int>::iterator it2 = vec[ vecid[v] ].begin(); it2 != vec[ vecid[v] ].end(); it2 ++) {
int x = *it2;
vec[ vecid[u] ].push_back(x);
cnt[ c[x] ] ++;
}
vec[ vecid[v] ].clear();
}
}
for (vector<int>::iterator it = qid[u].begin(); it != qid[u].end(); it ++) {
int id = *it;
query[id].ans = cnt[ query[id].c ];
}
if (!keep) { // 需要还原
for (vector<int>::iterator it = vec[ vecid[u] ].begin(); it != vec[ vecid[u] ].end(); it ++) {
cnt[ c[*it] ] --;
}
}
}
int main() {
cin >> n;
for (int i = 1; i <= n; i ++) cin >> c[i];
for (int i = 1; i < n; i ++) {
int a, b;
cin >> a >> b;
g[a].push_back(b);
g[b].push_back(a);
}
cin >> m;
for (int i = 0; i < m; i ++) {
int u, c;
cin >> u >> query[i].c;
qid[u].push_back(i);
}
dfs(1, -1, false);
for (int i = 0; i < m; i ++)
cout << query[i].ans << endl;
return 0;
}
3. 轻儿子-重儿子分解形式 (O(n log n))
这种格式开了一个 bool 类型的 big 数组,(big[u]) 用于标记当前节点 (u) 是不是某一个节点的重儿子,重儿子不需要还原。 这一步操作真的非常神奇!
虽然都是dsu on tree的实现,但是这种方式比前两种方式要更省空间(省的不知道哪里去了)。
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1000010;
vector<int> g[maxn];
int n, m, sz[maxn], c[maxn];
struct Query {
int c, ans;
} query[maxn];
vector<int> qid[maxn];
int cnt[maxn];
bool big[maxn];
void add(int u, int p, int x) {
cnt[ c[u] ] += x;
for (vector<int>::iterator it = g[u].begin(); it != g[u].end(); it ++) {
int v = *it;
if (v != p && !big[v])
add(v, u, x);
}
}
void dfs(int u, int p, bool keep) {
int mx = -1, bigSon = -1; // mx表示重儿子的sz,bigSon表示重儿子的编号
for (vector<int>::iterator it = g[u].begin(); it != g[u].end(); it ++) {
int v = *it;
if (v != p && sz[v] > mx) {
mx = sz[v];
bigSon = v;
}
}
for (vector<int>::iterator it = g[u].begin(); it != g[u].end(); it ++) {
int v = *it;
if (v != p && v != bigSon) dfs(v, u, false);
}
if (bigSon != -1) {
dfs(bigSon, u, true);
big[bigSon] = true;
}
add(u, p, 1);
for (vector<int>::iterator it = qid[u].begin(); it != qid[u].end(); it ++) {
int id = *it;
query[id].ans = cnt[ query[id].c ];
}
if (bigSon != -1)
big[bigSon] = 0;
if (!keep)
add(u, p, -1);
}
int main() {
cin >> n;
for (int i = 1; i <= n; i ++) cin >> c[i];
for (int i = 1; i < n; i ++) {
int a, b;
cin >> a >> b;
g[a].push_back(b);
g[b].push_back(a);
}
cin >> m;
for (int i = 0; i < m; i ++) {
int u, c;
cin >> u >> query[i].c;
qid[u].push_back(i);
}
dfs(1, -1, false);
for (int i = 0; i < m; i ++)
cout << query[i].ans << endl;
return 0;
}