- 描述
对于有根树T的两个节点u和v,最近公共祖先LCA(T,u,v)表示一个节点x满足x是u,v的公共祖先且x的深度尽可能大。
- 算法
求解LCA问题主要有三种解法,分别是暴力搜索,Tanjar算法,最后一种是转化为RMQ问题,用DFS+ST算法来求解
-
- 暴力搜索
如果数据量不大的时候可以采用暴力搜索法。先将节点u的祖先节点全部标记出来,然后顺着节点v沿着父亲节点的方向向上遍历,直到遍历到一个被标记的节点,这个节点即为所求节点。或者分别获取u,v到根节点的路径P1,P2,可以将这两条路径看做两个两个链表,问题即转化为求两个无环链表的交点。
eg hihocder 1062
#include<bits/stdc++.h>
using namespace std;
map<string,string> maps;
set<string> fas;
int main(){
int n;
string str1,str2;
cin>>n;
for(int i=0; i<n; i++){
cin>>str1>>str2;
maps[str2] = str1;
}
cin>>n;
while(n--){
cin>>str1>>str2;
fas.clear();
fas.insert(str1);
while(true){
if(maps.find(str1) == maps.end())
break;
fas.insert(maps[str1]);
str1 = maps[str1];
}
int flag = 0;
while(true){
if(fas.find(str2) != fas.end()){
flag = 1;
break;
}
if(maps.find(str2) == maps.end())
break;
str2 = maps[str2];
}
if(flag == 1)
cout<<str2<<endl;
else cout<<-1<<endl;
}
return 0;
}
-
- Tanjar算法
Tanjar算法是一种离线算法,所谓离线算法是指先读入所有查询之后,再统一处理,得到所有结果。Tanjar算法的基础是:设t是u,v的公共祖先之一,且u和v分别位于t的两颗不同的子树,则t一定是u,v的最近公共祖先。所以,Tanjar算法的基本流程是:从根节点开始,采用深度优先搜索的方法遍历整颗树。对于当前节点t,设置t的祖先节点设为t,再对该节点的每颗子树进行遍历,每搜索完一颗子树的时候即可处理完子树内部的查询,再把子节点的祖先节点设为t。当所有子树遍历完成的时候,再处理关于t结点的LCA询问v,如果v已经被访问过,由于进行的是DFS,此时t,v的公共祖先一定还没确定,而且这个公共祖先就是v当前的祖先节点。算法为代码如下:
Tanjar(u){
设置u的祖先节点为u
对于u的孩子节点v:
Tanjar(v)
Union(u,v)
对于关于u的查询节点v:
如果v已经被访问过,则LCA(u,v) = v的祖先节点
将u以及所有子节点的祖先节点设为u的父亲
}
eg hihocoder1067
#include<iostream>
#include<map>
#include<cstring>
#include<vector>
using namespace std;
const int maxm = 100010;
int n,cnt;
int fa[maxm];
map<string,int> maps;
vector<vector<int> > graph;
vector<pair<int,int> > query[maxm];
string names[maxm];
int res[maxm];
int find(int pos){
return pos==fa[pos] ? pos : fa[pos] = find(fa[pos]);
}
void solve(int node){
fa[node] = node;
int tmp;
for(int i=0; i<graph[node].size(); i++){
tmp = graph[node][i];
solve(tmp);
fa[tmp] = node;
}
for(int i=0; i<query[node].size(); i++){
pair<int,int> ps = query[node][i];
if(fa[ps.first] != -1)
res[ps.second] = find(ps.first);
}
}
int main(){
cin>>n;
string str1,str2;
int tp1,tp2;
memset(fa,-1,sizeof(fa));
cnt = 0;
for(int i=0; i<n; i++){
cin>>str1>>str2;
if(maps.find(str1) == maps.end()){
names[cnt] = str1;
maps[str1] = cnt++;
graph.push_back(vector<int>());
}
if(maps.find(str2) == maps.end()){
names[cnt] = str2;
maps[str2] = cnt++;
graph.push_back(vector<int>());
}
graph[maps[str1]].push_back(maps[str2]);
}
cin>>n;
int sp = 0;
for(int i=0; i<n; i++){
cin>>str1>>str2;
tp1 = maps[str1];
tp2 = maps[str2];
query[tp1].push_back(make_pair(tp2,sp));
query[tp2].push_back(make_pair(tp1,sp));
sp++;
}
solve(0);
int ps = 0;
while(sp--){
cout<<names[res[ps++]]<<endl;
}
return 0;
}
-
- RMQ(在线算法)
最后一种解法是转化为RMQ问题,用DFS+ST算法的方式来求解。这种解法的基本思想是:如果将树看成无向图,则u,v的最近公共祖先一定在u,v的最短路径之上。该解法的思路为:
- DFS。从树T的根开始,进行DFS,并记录下每次到达的顶点,每经过一条边都记录它的端点,由于每条边恰好经过2次,因此一共记录了2n-1个结点,用E[1, ... , 2n-1]来表示,E中存储的为节点编号
- 计算R。再遍历的过程中,记录每一个节点第一次被遍历到的次序,用R来表示,即R[i]表示节点E中出现i的最小下标
- 求解。如果R[u] < R[v],则u,v之间的最短路径为E[R[u]],E[R[u]+1]...E[R[v]],则该路径中深度最小的肯定是u,v的最小公共祖先。R[u]>R[v]的时候同理。
- 通过RMQ求解深度最小的节点。通过数组L来表示E中每一个节点对应的深度,所以数组L再DFS遍历的时候也可以求得。查找u,v的LCA的时,只需查找L[R[u]],L[R[u]+1]...L[R[v]]折一段中深度最小的元素的下标即可,通过该下标即可从E数组中得到对应的节点
eg hihocoder 1069
#include<iostream>
#include<map>
#include<cstring>
#include<vector>
using namespace std;
const int MAXM = 100010;
const int LEN = 20;
int cnt,cnt1;
map<string,int> maps;
string name[MAXM];
vector<int> graph[MAXM];
vector<pair<int,int> > query[MAXM];
int H[MAXM],L[2*MAXM],list[2*MAXM];
int dp[MAXM][LEN];
void dfs(int pos,int deep){
if(H[pos] == -1){
H[pos] = cnt1;
}
L[cnt1] = deep;
list[cnt1++] = pos;
for(int i=0; i<graph[pos].size(); i++){
int sp = graph[pos][i];
dfs(sp,deep+1);
L[cnt1] = deep;
list[cnt1++] = pos;
}
}
void init(){
memset(H,-1,sizeof(H));
dfs(0,0);
for(int i=0; i<cnt1; i++)
dp[i][0] = i;
for(int j=1; j<LEN; j++){
int lens = 1<<j;
for(int i=0; i<cnt1; i++){
if(i+lens >= cnt1) break;
int sp1 = dp[i][j-1];
int sp2 = dp[i+(1<<(j-1))][j-1];
if(L[sp1] < L[sp2]){
dp[i][j] = sp1;
}else dp[i][j] = sp2;
}
}
}
int querys(int pos1, int pos2){
if(pos2 < pos1) swap(pos1,pos2);
int len = pos2-pos1+1;
int j;
for(j=0; (1<<j)<=len; j++);
j--;
if((1<<j) == len) return dp[pos1][j];
int sp1 = dp[pos1][j];
int sp2 = querys(pos1+(1<<j),pos2);
if(L[sp1] < L[sp2]) return sp1;
return sp2;
}
int main(){
int n,m;
cin>>n;
cnt = 0;
cnt1 = 0;
string str1,str2;
while(n--){
cin>>str1>>str2;
if(maps.find(str1)==maps.end()){ maps[str1]=cnt; name[cnt]=str1; cnt++;}
if(maps.find(str2)==maps.end()){ maps[str2]=cnt; name[cnt]=str2; cnt++;}
int ps1 = maps[str1];
int ps2 = maps[str2];
graph[ps1].push_back(ps2);
}
init();
cin>>m;
while(m--){
cin>>str1>>str2;
int ps1 = maps[str1];
int ps2 = maps[str2];
int sp = querys(H[ps1],H[ps2]);
cout<<name[list[sp]]<<endl;
}
return 0;
}