题意:给定一棵树,求任意两点之间的距离。
思路:由于树的特殊性,所以任意两点之间的路径是唯一的。u到v的距离等于dis(u) + dis(v) - 2 * dis(lca(u, v)); 其中dis(u)表示u到根节点的距离。
RMQ求LCA,过程如下,摘自http://dongxicheng.org/structure/lca-rmq/
在线算法DFS+ST描述(思想是:将树看成一个无向图,u和v的公共祖先一定在u与v之间的最短路径上):
(1)DFS:从树T的根开始,进行深度优先遍历(将树T看成一个无向图),并记录下每次到达的顶点。第一个的结点是root(T),每经过一条边都记录它的端点。由于每条边恰好经过2次,因此一共记录了2n-1个结点,用E[1, ... , 2n-1]来表示。
(2)计算R:用R[i]表示E数组中第一个值为i的元素下标,即如果R[u] < R[v]时,DFS访问的顺序是E[R[u], R[u]+1, …, R[v]]。虽然其中包含u的后代,但深度最小的还是u与v的公共祖先。
(3)RMQ:当R[u] ≥ R[v]时,LCA[T, u, v] = RMQ(L, R[v], R[u]);否则LCA[T, u, v] = RMQ(L, R[u], R[v]),计算RMQ。
由于RMQ中使用的ST算法是在线算法,所以这个算法也是在线算法。
代码如下(LCA模板):
//LCA algorithm templet #include <cstdio> #include <iostream> #include <cstring> #include <cmath> #include <cstdlib> #include <algorithm> using namespace std; typedef long long ll; const int maxn = 40010; int tot, head[maxn]; struct Edge { int to, next; int w; }edge[maxn<<1];//edge int Euler[maxn<<1];//Euler sequence int R[maxn];//the R one visit int dep[maxn<<1];//depth int dis[maxn<<1];//dis[i] represent the distance between i and root int cnt;//the counter void init() { tot = 0; cnt = 0; memset(head, -1, sizeof(head)); memset(R, 0, sizeof(R)); memset(dis, 0, sizeof(dis)); } void addedge(int u, int v, int w) { edge[tot].to = v; edge[tot].w = w; edge[tot].next = head[u]; head[u] = tot++; } void dfs(int u, int fa, int depth, int dist) { Euler[++cnt] = u; dis[cnt] = dist; dep[cnt] = depth; R[u] = cnt; for (int i = head[u]; i != -1; i = edge[i].next) { int v = edge[i].to; if (v == fa) continue; dfs(v, u, depth + 1, dist + edge[i].w); Euler[++cnt] = u; dis[cnt] = dist; dep[cnt] = depth; } } int Rmin[maxn * 2][32];//Rmin represent the number(order number) of node void RMQ(int n) { for (int i = 1; i <= n; i++) Rmin[i][0] = i;//initalization int k = (int)log2(n); for (int j = 1; j <= k; j++) { for (int i = 1; i + (1 << j) - 1 <= n; i++)//按照dep来找最小值 Rmin[i][j] = dep[Rmin[i][j - 1]] < dep[Rmin[i + (1 << (j - 1))][j - 1]] ? Rmin[i][j - 1] : Rmin[i + (1 << (j - 1))][j - 1]; } } //找到u和v的距离 int query(int u, int v) { int l = R[u], r = R[v]; if (l > r) swap(l, r); int k = (int)log2(r - l + 1); int tmp = dep[Rmin[l][k]] < dep[Rmin[r - (1 << k) + 1][k]] ? Rmin[l][k] : Rmin[r - (1 << k) + 1][k]; //return Euler[tmp];//这里是返回u和v的公共祖先 return dis[l] + dis[r] - 2 * dis[tmp];//这里返回距离 } int main() { int n, tmp; while (~scanf("%d %d", &n, &tmp)) { init(); int u, v, w; for (int i = 1; i < n; i++) { scanf("%d %d %d %*s", &u, &v, &w); addedge(u, v, w); addedge(v, u, w); } dfs(1, 0, 1, 0); RMQ(cnt); int Q; scanf("%d", &Q); while (Q--) { scanf("%d %d", &u, &v); printf("%d ", query(u, v)); } } return 0; }