[2019知临中学csp模拟] 树上相交路径
题目描述
一棵有n个节点的树,有 m 条路径,求相交路径的对数。
两条路径,只要有一个点相同,就算相交。
(mathbb{Solution})
首先有结论:两条路径相交,必有一条路径的 LCA 在另一条路径上。
证明:你可以试试画反例,实际上树的形态限制了这点。
Operations:
记 s[i] 表示 i 是多少条路径的 LCA. 树上差分计算出每条路径上有多少个其他路径的 LCA。
记树上前缀和数组为 ss, 对于每条路径 x, y ,计算 (ss_x +ss_y - 2 imes ss_{LCA})
注意此时应剖除当前路径本身的 LCA 的 s,因为这样的点两条路径都会计算到对方,可以单独考虑。
最后计算作为多条路径 LCA 的点,两两组合的对数是 (dfrac{s_i * (s_i - 1)}{2})。
(mathbf{Code:})
#include <bits/stdc++.h>
#define swap(x, y) (x ^= y ^= x ^= y)
const int N = 1e5 + 10;
#define Rep(i, a, b) for (int i = (a), bb = (b); i <= bb; ++i)
#define S_H(T, i, u) for (int i = T.fl[u], v; v = T.to[i], i; i = T.net[i])
using namespace std;
int n, m, x, y;
struct Tree {
int to[N << 1], net[N << 1], fl[N], len;
inline void inc(int x, int y) { return to[++len] = y, net[len] = fl[x], fl[x] = len, void(); }
} T;
struct Chain { int x, y, Lca; } q[N];
long long s[N], ss[N], sum = 0;
template <class T>
inline void read(T &s) { s = 0; char c = getchar(); for (; !isdigit(c); c = getchar()); for (; isdigit(c); c = getchar()) s = (s << 3) + (s << 1) + c - 48; }
int d[N], top[N], hs[N], si[N], f[N];
inline void Dfs(int u, int fa) { f[u] = fa, d[u] = d[fa] + 1, si[u] = 1; S_H(T, i ,u) { if (v == fa) continue; Dfs(v, u), (si[hs[u]] < si[v] ? hs[u] = v : 0); } }
inline void Dfs_chain(int u, int k) { top[u] = k; if (!hs[u]) return void(); Dfs_chain(hs[u], k); S_H(T, i, u) { if (v == hs[u] || v == f[u]) continue; Dfs_chain(v, v); } }
inline void Get(int u, int fa) { ss[u] = s[u] + ss[fa]; S_H(T, i, u) { if (v == fa) continue; Get(v, u); } }
inline int Lca(int x, int y) {while(top[x]^top[y]){(d[top[x]]<d[top[y]]?swap(x, y):0), x=f[top[x]];}return d[x]<d[y]?x:y;}
int main(void) {
freopen("tree.in", "r", stdin), freopen("tree.out", "w", stdout);
memset(s, 0, sizeof s);
read(n), read(m); Rep(i, 1, n - 1) { read(x), read(y), T.inc(x, y), T.inc(y, x); }
Dfs(1, 0), Dfs_chain(1, 1); Rep(i, 1, m) read(q[i].x), read(q[i].y), q[i].Lca = Lca(q[i].x, q[i].y), ++s[q[i].Lca];
ss[0] = 0; Get(1, 0); Rep(i, 1, m) { int x = q[i].x, y = q[i].y, Lc = q[i].Lca; sum += ss[x] + ss[y] - 2 * ss[Lc]; }
Rep(i, 1, n) if (s[i] > 1) sum += 1LL * s[i] * (s[i] - 1) / 2; printf("%lld
", sum);
return 0;
}