题目链接:https://www.acwing.com/problem/content/357/
题目大意:给定一棵树 , 要求实现以下三种操作: 将某个点进行标记, 将某个点撤销标记 , 求出连接所有标记的点的最短路径
solution
笔者通过查阅他人的Blog , 了解到这道题有一个有趣的性质 , 即把所有有异象石的点按照 dfs 序排序并围成一圈 , 相邻两个节点的距离之和即为询问答案的两倍 , 但均未给出证明 , 经过思考后 , 以下是笔者乱搞出的一种证明 , 不感兴趣的可以跳过:
1.相邻两个节点的路径一定在连接所有标记的点的最短路径中 , 这很好证明 , 在树上 , 两个点之间的路径是唯一的 , 连接所有标记的点的最短路径必然包括 dfs 序相邻两个点的路径
2.连接所有标记的点的最短路径一定都被连接相邻两个节点的路径所包含 , 这也很好证明 , 连接相邻两个节点的路径已经将这些标记点都联通了 , 既然是连接所有标记的点的最短路径 , 一定不会比连接相邻两个节点的路径再多
由上可知 , 所有连接相邻两个节点的路径必然是连接所有标记的点的最短路径 , 且包含了所有路径 , 并且可以把它看做是从 dfs 序出发 , 遍历树上其它标记点 , 再回到起点 , 由树的性质可知 , 每条路径一定走过两遍 , 因此 , 相邻两个节点的距离之和即为询问答案的两倍 , 得证
有了这个结论以后 , 之后就好办了 , 可以用平衡树或 set 之类的进行维护 , 如果插入一个点 (x_i) , 那么找到它相邻的两个点 (l_i) , (r_i) , 询问答案的两倍 (res) 就加上 (dist(l_i , x_i)+dist(x_i , r_i)-dist(l_i , r_i)) , 如果删除一个点 (x_i) , 那么找到它相邻的两个点 (l_i) , (r_i) , 询问答案的两倍 (res) 就减去 (dist(l_i , x_i)+dist(x_i , r_i)-dist(l_i , r_i)) , 如果进行询问 , 那么就输出 (res / 2)
时间复杂度 : (O(mlogn))
code
#include<bits/stdc++.h>
using namespace std;
template <typename T> inline void read(T &FF) {
int RR = 1; FF = 0; char CH = getchar();
for(; !isdigit(CH); CH = getchar()) if(CH == '-') RR = -RR;
for(; isdigit(CH); CH = getchar()) FF = FF * 10 + CH - 48;
FF *= RR;
}
inline void file(string str) {
freopen((str + ".in").c_str(), "r", stdin);
freopen((str + ".out").c_str(), "w", stdout);
}
#define int long long
const int N = 2e5, Log = 21;
int n, m, fa[N][Log + 1], dep[N], ans, dis[N];
int now, si, fst[N], nxt[N], num[N], wi[N], dfn[N];
void add(int u, int v, int w) {
nxt[++now] = fst[u], fst[u] = now, num[now] = v, wi[now] = w;
nxt[++now] = fst[v], fst[v] = now, num[now] = u, wi[now] = w;
}
void pre_lca(int xi) {
dep[xi] = dep[fa[xi][0]] + 1; dfn[xi] = ++si;
for(int i = 1; i <= Log; i++)
fa[xi][i] = fa[fa[xi][i - 1]][i - 1];
for(int i = fst[xi]; i; i = nxt[i])
if(num[i] != fa[xi][0])
dis[num[i]] = dis[xi] + wi[i], fa[num[i]][0] = xi, pre_lca(num[i]);
}
int lca(int xi, int yi) {
if(dep[xi] < dep[yi]) swap(xi, yi);
for(int i = Log; i >= 0; i--)
if(dep[fa[xi][i]] >= dep[yi])
xi = fa[xi][i];;
if(xi == yi) return xi;
for(int i = Log; i >= 0; i--)
if(fa[xi][i] != fa[yi][i])
xi = fa[xi][i], yi = fa[yi][i];
return fa[xi][0];
}
int dist(int xi, int yi) {
return dis[xi] + dis[yi] - 2 * dis[lca(xi, yi)];
}
set<pair<int, int> > pi;
signed main() {
//file("");
int u, v, w, xi; char op;
read(n);
for(int i = 1; i < n; i++)
read(u), read(v), read(w), add(u, v, w);
pre_lca(1);
read(m);
for(int i = 1; i <= m; i++) {
cin >> op;
if(op == '+') {
read(xi);
pi.insert(make_pair(dfn[xi], xi));
if(pi.size() > 1) {
set<pair<int, int> >::iterator qi = pi.find(make_pair(dfn[xi], xi)), pre = qi, succ = qi;
if(pre == pi.begin()) pre = (--pi.end()), succ++;
else if(succ == (--pi.end())) succ = pi.begin(), pre--;
else pre++, succ--;
int li = (*pre).second, ri = (*succ).second;
ans += dist(li, xi) + dist(xi, ri) - dist(li, ri);
}
}
else if(op == '-') {
read(xi);
if(pi.size() > 0) {
set<pair<int, int> >::iterator qi = pi.find(make_pair(dfn[xi], xi)), pre = qi, succ = qi;
if(pre == pi.begin()) pre = (--pi.end()), succ++;
else if(succ == (--pi.end())) succ = pi.begin(), pre--;
else pre++, succ--;
int li = (*pre).second, ri = (*succ).second;
ans -= dist(li, xi) + dist(xi, ri) - dist(li, ri);
}
pi.erase(make_pair(dfn[xi], xi));
}
else cout << ans / 2 << '
';
}
return 0;
}