https://daniu.luogu.org/problem/show?pid=2146
树链剖分(树剖)模板题。树上(指题目给出的依赖树)每个结点有"已安装"和"未安装"两种状态,分别对应计数1和0。按重链剖分得到的DFS序把结点映射到线段树上,由线段树维护区间计数之和。
安装操作即把节点x到根的路径上所有节点的计数改为1,统计更改了多少节点并输出。
删除操作即把节点x及它的所有后代节点计数改为0,统计更改了多少节点并输出。
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Package manager (Luogu P2146): heavy-light decomposition + segment tree.
// Each tree node is either installed (1) or not installed (0); node 0 is the
// root. Note: no `using namespace std;` -- with it, the global arrays formerly
// named `hash`/`size` were ambiguous with std::hash / std::size (compile error).

constexpr int maxn = 100010;  // capacity for node arrays (n <= 100000)

int n, m;

namespace seg {

// Segment tree over positions [1, n]: range-assign (0 or 1) and range-sum.
struct node {
  int ln, rn, mn;  // covered range [ln, rn] and its split point mn
  int len;         // rn - ln + 1
  int cnt;         // number of 1s in the range (valid when no ancestor is marked)
  int mark;        // pending assignment for the whole range, or -1 for none
};

node nds[maxn * 4];

// Hands node p's pending assignment to both children and clears it on p.
void push_down(int p) {
  if (nds[p].mark != -1 && nds[p].len > 1) {
    nds[p * 2].mark = nds[p * 2 + 1].mark = nds[p].mark;
  }
  nds[p].mark = -1;
}

// Recomputes cnt of node p from its own mark, or from its children if unmarked.
void pull_up(int p) {
  if (nds[p].mark != -1) {
    nds[p].cnt = nds[p].mark * nds[p].len;
  } else if (nds[p].len > 1) {
    nds[p].cnt = nds[p * 2].cnt + nds[p * 2 + 1].cnt;
  }
}

// Builds the subtree for [l, r] at index p with every position equal to 0.
void init(int l, int r, int p = 1) {
  node &nd = nds[p];
  nd.ln = l;
  nd.rn = r;
  nd.mn = (l + r) / 2;
  nd.len = r - l + 1;
  nd.cnt = 0;
  nd.mark = -1;
  if (l != r) {
    init(nd.ln, nd.mn, p * 2);
    init(nd.mn + 1, nd.rn, p * 2 + 1);
  }
}

// Assigns val to every position in [l, r]; [l, r] must lie within node p.
void set(int l, int r, int val, int p = 1) {
  node &nd = nds[p];
  if (nd.ln == l && nd.rn == r) {
    nd.mark = val;
  } else {
    push_down(p);
    // A child not touched by the recursion may just have received a mark
    // from push_down, so its cnt must be refreshed explicitly.
    if (l <= nd.mn) {
      set(l, std::min(nd.mn, r), val, p * 2);
    } else {
      pull_up(p * 2);
    }
    if (nd.mn + 1 <= r) {
      set(std::max(l, nd.mn + 1), r, val, p * 2 + 1);
    } else {
      pull_up(p * 2 + 1);
    }
  }
  pull_up(p);
}

// Returns the number of 1s in [l, r]; [l, r] must lie within node p.
int query(int l, int r, int p = 1) {
  const node &nd = nds[p];
  if (nd.mark != -1) {
    // The topmost marked node on the path is authoritative for its range.
    return nd.mark * (r - l + 1);
  }
  if (nd.ln == l && nd.rn == r) {
    return nd.cnt;
  }
  int ans = 0;
  if (l <= nd.mn) ans += query(l, std::min(nd.mn, r), p * 2);
  if (nd.mn + 1 <= r) ans += query(std::max(l, nd.mn + 1), r, p * 2 + 1);
  return ans;
}

}  // namespace seg

int parent[maxn];
std::vector<int> child[maxn];

// Records that c depends on (is a child of) p.
void add_child(int p, int c) {
  parent[c] = p;
  child[p].push_back(c);
}

int subtree_size[maxn];  // number of nodes in the subtree rooted at each node
int heavy[maxn];         // child with the largest subtree (first on ties), or -1

// Fills subtree_size and heavy for every node reachable from root.
// Iterative so that a 1e5-node chain cannot overflow the call stack: nodes are
// collected in BFS order (parents before children) and processed in reverse.
void dfs1(int root = 0) {
  std::vector<int> order;
  order.reserve(n > 0 ? n : 1);
  order.push_back(root);
  for (std::size_t i = 0; i < order.size(); ++i) {
    for (int c : child[order[i]]) order.push_back(c);
  }
  for (auto it = order.rbegin(); it != order.rend(); ++it) {
    const int k = *it;
    subtree_size[k] = 1;
    heavy[k] = -1;
    int max_size = 0;
    for (int c : child[k]) {
      subtree_size[k] += subtree_size[c];
      if (max_size < subtree_size[c]) {
        max_size = subtree_size[c];
        heavy[k] = c;
      }
    }
  }
}

int top[maxn];  // head of the heavy chain containing each node
int dfn[maxn];  // 1-based position of each node in heavy-first DFS preorder
int cnt = 1;    // next position to hand out

// Assigns dfn and top in preorder, visiting the heavy child first so each
// heavy chain and each subtree occupy contiguous position ranges.
// Iterative: light children are pushed in reverse (so they pop in original
// order) and the heavy child last (so it pops next).
void dfs2(int root = 0) {
  top[root] = root;
  std::vector<int> pending{root};
  while (!pending.empty()) {
    const int k = pending.back();
    pending.pop_back();
    dfn[k] = cnt++;
    for (auto it = child[k].rbegin(); it != child[k].rend(); ++it) {
      const int c = *it;
      if (c != heavy[k]) {
        top[c] = c;
        pending.push_back(c);
      }
    }
    if (heavy[k] != -1) {
      top[heavy[k]] = top[k];
      pending.push_back(heavy[k]);
    }
  }
}

// Installs every node on the path from k up to the root (node 0) and returns
// how many of them were not installed before.
int install(int k) {
  int ans = 0;
  while (true) {
    const int lo = dfn[top[k]];
    const int hi = dfn[k];
    ans += (hi - lo + 1) - seg::query(lo, hi);
    seg::set(lo, hi, 1);
    if (top[k] == 0) break;
    k = parent[top[k]];
  }
  return ans;
}

// Uninstalls k and all of its descendants (a contiguous dfn range) and
// returns how many of them were installed before.
int uninstall(int k) {
  const int lo = dfn[k];
  const int hi = dfn[k] + subtree_size[k] - 1;
  const int ans = seg::query(lo, hi);
  seg::set(lo, hi, 0);
  return ans;
}

int main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);
  std::cin >> n;
  for (int i = 1; i <= n - 1; ++i) {
    int a;
    std::cin >> a;
    add_child(a, i);
  }
  dfs1();
  dfs2();
  seg::init(1, n);
  std::cin >> m;
  std::string str;
  int a;
  while (m-- > 0 && std::cin >> str >> a) {
    switch (str[0]) {
      case 'i':
        std::cout << install(a) << '\n';  // '\n': avoid a flush per answer
        break;
      case 'u':
        std::cout << uninstall(a) << '\n';
        break;
    }
  }
  return 0;
}