LCA
在有根树中,两个节点 u 和 v 的公共祖先中距离最近的那个被称为最近公共祖先(LCA,Lowest Common Ancestor)。
有多种算法解决 LCA 或相关的问题。
基于二分搜索的算法
首先搜索树中各个节点的深度;
const int MAXN = 4e4 + 5; // 最大节点数
const int LOG_N = 60; // 树的最大深度
vector<int> G[MAXN]; // 树
int depth[MAXN]; // 节点深度
int parent[LOG_N][MAXN]; // parent[k][i]表示 i 向上走 2^k 步能到达的节点
void dfs(int pre, int u, int d)
{
parent[0][u] = pre;
depth[u] = d;
for(int i = 0; i < G[u].size(); i++)
{
int v = G[u][i];
if(v != pre) dfs(u, v, d + 1);
}
}
对于任意节点,通过节点和其父节点的信息,都能得到其和父亲的父亲节点的关系,即可以得到向上走 2 步所能到达的节点的值;
那么,同样可以得到向上走 4 步所能到达节点的值;后面同理。
而树的深度很小,所以可以预处理所有点;
void init()
{
int root = 1;
dfs(-1, root, 0);
for(int k = 1; k < LOG_N; k++)
{
for(int i = 1; i <= n; i++)
{
if(parent[k - 1][i] < 0) parent[k][i] = -1;
else parent[k][i] = parent[k - 1][parent[k - 1][i]];
}
}
}
计算 LCA 时,首先让它们到达同一深度,在同时向上搜索最近公共祖先即可。
int lca(int u, int v)
{
if(depth[u] > depth[v]) swap(u, v);
for(int i = 0; i < LOG_N; i++) // u 和 v 向上走到同一深度
{
if((depth[v] - depth[u]) >> i & 1) // 把 (depth[v] - depth[i]) 化成二进制后可以看到,就是找到所有 1 的位置
{
v = parent[i][v];
}
}
if(v == u) return u;
for(int i = LOG_N - 1; i >= 0; i--) // 找 lca
{
if(parent[i][u] != parent[i][v]) // 如果相同,那么一定是公共祖先或公共祖先之上的节点
{
u = parent[i][u];
v = parent[i][v];
}
}
return parent[0][u];
}
Tarjan 离线算法
一道模板题,求二叉树中两个节点的最短距离,就是 dis[u] + dis[v] - 2 * dis[lca(u,v)]
;
Tarjan离线算法,先读入所有查询,直接算出所有答案。
其实就是利用 DFS 遍历二叉树的特性,以及并查集的优化,
首先,从 1 向下搜,一直搜到 8 ,在这过程中,对于查询的边 u - v ,节点 u 对应的 v 已经访问过(如 4 - 2),那么 found(2) 就是 LCA(4, 2) ;
搜到 8 ,会回到 4 -> 9 -> 4 -> 2,再去搜 5 ,如果查询的节点是 (5 - 8) 5 对应的 8 已被访问过,那么 LCA(5, 8) = found(8),因为到现在为止的 DFS 都在 2 这个节点之下,所以只要没回到 1, found(2) = 2 保持不变,即 LCA(5, 8) = found(2) = 2,后面的同理;
所以关键就是遍历完一个节点的所有子树之后在去指定这个节点的父亲节点;
#include<cstdio>
#include<cstring>
#include<vector>
#include<algorithm>
typedef long long ll;
const int INF = 1e9;
const int MAXN = 4e4 + 10;
using namespace std;
typedef pair<int, int> P;
int n, m;
int p[MAXN]; // 并查集祖先节点
int q[MAXN]; // 对应第几次查询的 lca
int ex[MAXN], ey[MAXN]; // 记录查询的边
int vis[MAXN]; // 标记数组
int dis[MAXN]; // 离根节点的距离
vector<P> G[MAXN];
vector<P> edges[MAXN];
int found(int x)
{
return x == p[x] ? x : (p[x] = found(p[x]));
}
void tarjan(int pre, int u, int len)
{
vis[u] = 1;
dis[u] = dis[pre] + len;
for(int i = 0; i < edges[u].size(); i++)
{
P v = edges[u][i];
if(vis[v.first]) q[v.second] = found(v.first);
}
for(int i = 0; i < G[u].size(); i++)
{
P v = G[u][i];
if(v.first != pre)
{
tarjan(u, v.first, v.second);
p[v.first] = u;
}
}
}
int main()
{
int T;
scanf("%d", &T);
while(T--)
{
scanf("%d%d", &n, &m);
memset(vis, 0, sizeof vis);
for(int i = 0; i <= n; i++)
{
p[i] = i;
vis[i] = 0;
G[i].clear();
edges[i].clear();
}
for(int i = 1; i < n; i++)
{
int x, y, z;
scanf("%d%d%d", &x, &y, &z);
G[x].push_back(P(y, z));
G[y].push_back(P(x, z));
}
for(int i = 0; i < m; i++)
{
int x, y;
scanf("%d%d", &x, &y);
ex[i] = x; ey[i] = y;
edges[x].push_back(P(y, i));
edges[y].push_back(P(x, i));
}
tarjan(0, 1, 0);
for(int i = 0; i < m; i++)
{
printf("%d
", dis[ex[i]] + dis[ey[i]] - 2 * dis[q[i]]);
}
}
return 0;
}