- 描述
对于有根树T的两个节点u和v,最近公共祖先LCA(T,u,v)表示一个节点x满足x是u,v的公共祖先且x的深度尽可能大。
- 算法
求解LCA问题主要有三种解法,分别是暴力搜索,Tanjar算法,最后一种是转化为RMQ问题,用DFS+ST算法来求解
-
- 暴力搜索
如果数据量不大的时候可以采用暴力搜索法。先将节点u的祖先节点全部标记出来,然后顺着节点v沿着父亲节点的方向向上遍历,直到遍历到一个被标记的节点,这个节点即为所求节点。或者分别获取u,v到根节点的路径P1,P2,可以将这两条路径看做两个两个链表,问题即转化为求两个无环链表的交点。
eg hihocder 1062
#include<bits/stdc++.h> using namespace std; map<string,string> maps; set<string> fas; int main(){ int n; string str1,str2; cin>>n; for(int i=0; i<n; i++){ cin>>str1>>str2; maps[str2] = str1; } cin>>n; while(n--){ cin>>str1>>str2; fas.clear(); fas.insert(str1); while(true){ if(maps.find(str1) == maps.end()) break; fas.insert(maps[str1]); str1 = maps[str1]; } int flag = 0; while(true){ if(fas.find(str2) != fas.end()){ flag = 1; break; } if(maps.find(str2) == maps.end()) break; str2 = maps[str2]; } if(flag == 1) cout<<str2<<endl; else cout<<-1<<endl; } return 0; }
-
- Tanjar算法
Tanjar算法是一种离线算法,所谓离线算法是指先读入所有查询之后,再统一处理,得到所有结果。Tanjar算法的基础是:设t是u,v的公共祖先之一,且u和v分别位于t的两颗不同的子树,则t一定是u,v的最近公共祖先。所以,Tanjar算法的基本流程是:从根节点开始,采用深度优先搜索的方法遍历整颗树。对于当前节点t,设置t的祖先节点设为t,再对该节点的每颗子树进行遍历,每搜索完一颗子树的时候即可处理完子树内部的查询,再把子节点的祖先节点设为t。当所有子树遍历完成的时候,再处理关于t结点的LCA询问v,如果v已经被访问过,由于进行的是DFS,此时t,v的公共祖先一定还没确定,而且这个公共祖先就是v当前的祖先节点。算法为代码如下:
Tanjar(u){ 设置u的祖先节点为u 对于u的孩子节点v: Tanjar(v) Union(u,v) 对于关于u的查询节点v: 如果v已经被访问过,则LCA(u,v) = v的祖先节点 将u以及所有子节点的祖先节点设为u的父亲 }
eg hihocoder1067
#include<iostream> #include<map> #include<cstring> #include<vector> using namespace std; const int maxm = 100010; int n,cnt; int fa[maxm]; map<string,int> maps; vector<vector<int> > graph; vector<pair<int,int> > query[maxm]; string names[maxm]; int res[maxm]; int find(int pos){ return pos==fa[pos] ? pos : fa[pos] = find(fa[pos]); } void solve(int node){ fa[node] = node; int tmp; for(int i=0; i<graph[node].size(); i++){ tmp = graph[node][i]; solve(tmp); fa[tmp] = node; } for(int i=0; i<query[node].size(); i++){ pair<int,int> ps = query[node][i]; if(fa[ps.first] != -1) res[ps.second] = find(ps.first); } } int main(){ cin>>n; string str1,str2; int tp1,tp2; memset(fa,-1,sizeof(fa)); cnt = 0; for(int i=0; i<n; i++){ cin>>str1>>str2; if(maps.find(str1) == maps.end()){ names[cnt] = str1; maps[str1] = cnt++; graph.push_back(vector<int>()); } if(maps.find(str2) == maps.end()){ names[cnt] = str2; maps[str2] = cnt++; graph.push_back(vector<int>()); } graph[maps[str1]].push_back(maps[str2]); } cin>>n; int sp = 0; for(int i=0; i<n; i++){ cin>>str1>>str2; tp1 = maps[str1]; tp2 = maps[str2]; query[tp1].push_back(make_pair(tp2,sp)); query[tp2].push_back(make_pair(tp1,sp)); sp++; } solve(0); int ps = 0; while(sp--){ cout<<names[res[ps++]]<<endl; } return 0; }
-
- RMQ(在线算法)
最后一种解法是转化为RMQ问题,用DFS+ST算法的方式来求解。这种解法的基本思想是:如果将树看成无向图,则u,v的最近公共祖先一定在u,v的最短路径之上。该解法的思路为:
- DFS。从树T的根开始,进行DFS,并记录下每次到达的顶点,每经过一条边都记录它的端点,由于每条边恰好经过2次,因此一共记录了2n-1个结点,用E[1, ... , 2n-1]来表示,E中存储的为节点编号
- 计算R。再遍历的过程中,记录每一个节点第一次被遍历到的次序,用R来表示,即R[i]表示节点E中出现i的最小下标
- 求解。如果R[u] < R[v],则u,v之间的最短路径为E[R[u]],E[R[u]+1]...E[R[v]],则该路径中深度最小的肯定是u,v的最小公共祖先。R[u]>R[v]的时候同理。
- 通过RMQ求解深度最小的节点。通过数组L来表示E中每一个节点对应的深度,所以数组L再DFS遍历的时候也可以求得。查找u,v的LCA的时,只需查找L[R[u]],L[R[u]+1]...L[R[v]]折一段中深度最小的元素的下标即可,通过该下标即可从E数组中得到对应的节点
eg hihocoder 1069
#include<iostream> #include<map> #include<cstring> #include<vector> using namespace std; const int MAXM = 100010; const int LEN = 20; int cnt,cnt1; map<string,int> maps; string name[MAXM]; vector<int> graph[MAXM]; vector<pair<int,int> > query[MAXM]; int H[MAXM],L[2*MAXM],list[2*MAXM]; int dp[MAXM][LEN]; void dfs(int pos,int deep){ if(H[pos] == -1){ H[pos] = cnt1; } L[cnt1] = deep; list[cnt1++] = pos; for(int i=0; i<graph[pos].size(); i++){ int sp = graph[pos][i]; dfs(sp,deep+1); L[cnt1] = deep; list[cnt1++] = pos; } } void init(){ memset(H,-1,sizeof(H)); dfs(0,0); for(int i=0; i<cnt1; i++) dp[i][0] = i; for(int j=1; j<LEN; j++){ int lens = 1<<j; for(int i=0; i<cnt1; i++){ if(i+lens >= cnt1) break; int sp1 = dp[i][j-1]; int sp2 = dp[i+(1<<(j-1))][j-1]; if(L[sp1] < L[sp2]){ dp[i][j] = sp1; }else dp[i][j] = sp2; } } } int querys(int pos1, int pos2){ if(pos2 < pos1) swap(pos1,pos2); int len = pos2-pos1+1; int j; for(j=0; (1<<j)<=len; j++); j--; if((1<<j) == len) return dp[pos1][j]; int sp1 = dp[pos1][j]; int sp2 = querys(pos1+(1<<j),pos2); if(L[sp1] < L[sp2]) return sp1; return sp2; } int main(){ int n,m; cin>>n; cnt = 0; cnt1 = 0; string str1,str2; while(n--){ cin>>str1>>str2; if(maps.find(str1)==maps.end()){ maps[str1]=cnt; name[cnt]=str1; cnt++;} if(maps.find(str2)==maps.end()){ maps[str2]=cnt; name[cnt]=str2; cnt++;} int ps1 = maps[str1]; int ps2 = maps[str2]; graph[ps1].push_back(ps2); } init(); cin>>m; while(m--){ cin>>str1>>str2; int ps1 = maps[str1]; int ps2 = maps[str2]; int sp = querys(H[ps1],H[ps2]); cout<<name[list[sp]]<<endl; } return 0; }