树上任意两点的最近祖先,必定就是这两个节点的最短路径上深度最小的那个点。
例如:下图中,节点7和5,其最短路径为7--4--1--5, 这条路径上深度最小的点为节点1,其深度为1.节点1即为节点7和5的LCA。
因此,要找到任意两个节点的LCA,只需要先找到上述最短路径,再找到最短路径中深度最小的点。而这下面所述LCA在线算法所做的事。
LCA在线算法描述(以上图为例):
1.获得“最短路径”(并不是真正的一条路径,包含其他节点,但不影响算法的正确性)
采用DFS遍历整棵树,得到以下数据:
(1)遍历序列p:0 1 3 1 4 7 4 8 4 1 5 1 0 2 6 2 0
(2)各节点的深度序列 depth: 0 1 1 2 2 2 2 3 3
(3)各节点在序列p中首次出现的位置序列pos: 0 1 13 2 4 10 14 5 7
有了以上数据,假设现在我们要求节点7和5的最短路径,我们可以这样做:
(1)首先,从pos序列中获得节点7和节点5在p序列中第一次出现的位置分别为:pos[7] = 5, pos[5] = 10;
(2)得到p序列中[5, 10]这一段子序列s:7 4 8 4 1 5
(3)s序列中深度最小的点即节点1就是我们要找的节点7和节点5的LCA。
注意到,此时的s序列并非是从节点7到节点5的一条最短路径,它除了包含7到5的最短路径上的节点外,还包含了一些其他的节点,但这些其他的节点都是以节点1为根的子树上的节点,他们的深度都比节点1大,不影响我们算法对正确结果的求解。
2.如何快速的获得一段序列中深度最小的节点
求解区间最值的问题是我们所熟悉的经典的RMQ问题,用RMQ-ST算法即可。由于RMQ-ST算法是在线的,故我们的LCA算法也是在线的。
下面是我的代码实现:
1 #include <iostream> 2 #include <string> 3 #include <map> 4 #include <vector> 5 #include <algorithm> 6 #include <cmath> 7 8 using namespace std; 9 10 #define MAXN 100005 11 12 map<string, int> mp; 13 string name[2*MAXN]; 14 vector<int> v[2*MAXN]; 15 int p[4*MAXN], depth[2*MAXN], pos[2*MAXN]; 16 int pre_cal[4*MAXN][20]; 17 int cnt, n, m; 18 19 void dfs(int i, int d) 20 { 21 pos[i] = cnt; 22 p[cnt++] = i; 23 depth[i] = d; 24 if(v[i].empty()) return; 25 for(int j=0; j<v[i].size(); ++j) 26 { 27 dfs(v[i][j], d+1); 28 p[cnt++] = i; 29 } 30 } 31 32 void rmq() 33 { 34 for(int i=0; i<4*MAXN; ++i) pre_cal[i][0] = i; 35 for(int j=1; (1<<(j-1))<4*MAXN; ++j) 36 for(int i=0; i+(1<<(j-1))<4*MAXN; ++i) 37 pre_cal[i][j] = depth[p[pre_cal[i][j-1]]]<depth[p[pre_cal[i+(1<<(j-1))][j-1]]]?pre_cal[i][j-1]:pre_cal[i+(1<<(j-1))][j-1]; 38 } 39 40 string lca(int a, int b) 41 { 42 int k = floor(log(b-a+1)/log(2)); 43 int x = pre_cal[a][k], y = pre_cal[b-(1<<k)+1][k]; 44 return depth[p[x]]<depth[p[y]]?name[p[x]]:name[p[y]]; 45 } 46 47 void init() 48 { 49 cnt = 0; 50 mp.clear(); 51 for(int i=0; i<2*MAXN; ++i) v[i].clear(); 52 } 53 54 int main() 55 { 56 string name1, name2; 57 while(cin>>n) 58 { 59 init(); 60 while(n--) 61 { 62 cin>>name1>>name2; 63 if(mp.find(name1)==mp.end()) 64 { 65 mp[name1] = cnt; 66 name[cnt++] = name1; 67 } 68 if(mp.find(name2)==mp.end()) 69 { 70 mp[name2] = cnt; 71 name[cnt++] = name2; 72 } 73 v[mp[name1]].push_back(mp[name2]); 74 } 75 cnt = 0; 76 dfs(0, 0); 77 rmq(); 78 cin>>m; 79 while(m--) 80 { 81 cin>>name1>>name2; 82 int a = mp[name1], b = mp[name2]; 83 cout<<(pos[a]<pos[b]?lca(pos[a], pos[b]):lca(pos[b], pos[a]))<<endl; 84 } 85 } 86 87 return 0; 88 }
题目链接:http://hihocoder.com/problemset/problem/1069