LCA:树上最近公共祖先
目前掌握的算法有两种(都是在线做法)
先说简单的:倍增
这里写代码片
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<cmath>
using namespace std;
struct node{
int u,v,next;
};
node way[1000005]; //记录边,用数组模拟的邻接表
int h,w,heap[500005];
int n,m,s,st[500005],a1,a2;
int f[500005][20],tot=0; //f[i][j]表示从i向上跳2^j步到达的节点
int deep[500005],mj;
bool p[500005];
int add(int x,int y) //加边
{
tot++;
way[tot].u=x;
way[tot].v=y;
way[tot].next=st[x];
st[x]=tot;
}
int bfs() //广搜,记录每个节点在树中的深度
{ //手写的队列,绝对良心
int i,j;
h=0;
w=1;
heap[w]=s;
deep[s]=1;
p[s]=0;
do
{
h++;
int r=heap[h];
for (i=st[r];i;i=way[i].next)
{
if (way[i].v!=f[r][0]&&p[way[i].v])
{
w++;
heap[w]=way[i].v;
p[way[i].v]=0;
deep[way[i].v]=deep[r]+1;
f[way[i].v][0]=r; //这个节点向上跳一步就是r(它的爸爸)
}
}
}
while (h<w);
}
int lca(int x,int y) //算法主体lca
{
int i,j;
if (deep[x]<deep[y])
swap(x,y);
int d=deep[x]-deep[y]; //两个查找节点的深度差
if (d) //这里的写法要特别注意,一不注意就会出现玄学的问题
for (i=0;i<=mj&&d;i++,d>>=1)
if (d&1)
x=f[x][i];
if (x==y) return x; //已经跳到一起了
for (i=mj;i>=0;i--)
if (f[x][i]!=f[y][i])
{
x=f[x][i];
y=f[y][i];
}
return (f[x][0]); //直到跳到只有一步就重合了 ,返回
}
int main()
{
memset(p,1,sizeof(p));
scanf("%d%d%d",&n,&m,&s);
int x,y;
for (int i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y); //这里一定要双向加边
add(y,x);
}
deep[s]=1;
bfs();
mj=(int)(log(n)/log(2))+1; //为避免出错,所以在这里+1
for (int i=1;i<=mj;i++) //lca的预处理
for (int j=1;j<=n;j++)
f[j][i]=f[f[j][i-1]][i-1];
//记录j向上跳2^i的爸爸
for (int i=1;i<=m;i++)
{
scanf("%d%d",&a1,&a2);
printf("%d
",lca(a1,a2));
}
}
重点还是st表的算法:
这个算法是基于RMQ(区间最大最小值编号)的,不懂的可以这里学习一些
而求LCA就是把树通过深搜得到一个序列,然后转化为求区间的最小编号
这里写代码片
/*https://wenku.baidu.com/view/78ceaf54ad02de80d4d8408d.html?from=search*/
#include<cstdio>
#include<cstring>
#include<iostream>
#include<cmath>
using namespace std;
const int N=100001;
struct node{
int x,y,nxt;
};
node way[N];
int st[N],tot=0;
int df[(N<<1)+5],top=0,deep[(N<<1)+5],first[N]; //df:dfs序,deep深度,first第一次出现的位置
int f[N][20];
int len,n,m,root; //序列总长度
void add(int u,int w)
{
way[++tot].x=u; way[tot].y=w;
way[tot].nxt=st[u]; st[u]=tot;
}
void RMQ()
{
int i,j;
for (i=1;i<=len;i++) f[i][0]=i;
//f数组里记录的是这个区间内最小值在dfs序中的下标
i=1,j=1;
for (j=1;i+(1<<j)-1<=len;j++)
for (i=1;i+(1<<j)-1<=len;i++) //我们要找的是deep最小的
if (deep[f[i][j-1]]<deep[f[i+(1<<j)][j-1]])
f[i][j]=f[i][j-1];
else
f[i][j]=f[i+(1<<j)][j-1];
}
void dfs(int now,int fa,int dep)
{
df[++top]=now;
deep[now]=dep;
if (!first[now])
first[now]=top; //first 出现
for (int i=st[now];i;i=way[i].nxt)
if (way[i].y!=fa) //
{
dfs(way[i].y,now,dep+1);
df[++top]=now;
deep[top]=dep;
}
}
int ask(int l,int r)
{
int ln=r-l+1;
int d=(int)log(ln)/log(2); //区间长度需要多长覆盖
d--; //防止超出区间
int a=deep[f[l][d]];
int b=deep[f[r-(1<<d)+1][d]]; //深度
if (a<b)
return df[f[l][d]];
else return df[f[r-(1<<d)+1][d]];
}
int lca(int x,int y)
{
if (first[x]>first[y]) //first要从小到大
return ask(first[x],first[y]);
}
int main()
{
scanf("%d%d%d",&n,&m,&root);
for (int i=1;i<n;i++)
{
int x,y;
scanf("%d%d",&x,&y);
add(x,y); //双向加边
add(y,x);
}
dfs(root,-1,0);
len=2*n-1;
RMQ();
for (int i=1;i<=m;i++)
{
int x,y;
scanf("%d%d",&x,&y);
printf("%d
",lca(x,y));
}
return 0;
}