LCA求解的四种模板
树剖在线求解LCA
思想
树剖这里就不多解释了,求解LCA的过程就是轻重链的跳转,跟树剖求任意两点间的距离一样的操作,只不过不用线段树去维护(dis)了,那就直接上代码吧。
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e6 + 10;
int head[N], to[N << 1], nex[N << 1], cnt = 1;
int sz[N], dep[N], fa[N], son[N], top[N];
int n, m;
inline int read() {
int f = 1, x = 0;
char c = getchar();
while(c > '9' || c < '0') {
if(c == '-') f = -1;
c = getchar();
}
while(c >= '0' && c <= '9') {
x = (x << 1) + (x << 3) + (c ^ 48);
c = getchar();
}
return f * x;
}
void add(int x, int y) {
to[cnt] = y;
nex[cnt] = head[x];
head[x] = cnt++;
}
void dfs1(int rt, int f) {
dep[rt] = dep[f] + 1;
sz[rt] = 1, fa[rt] = f;
for(int i = head[rt]; i; i = nex[i]) {
if(to[i] == f) continue;
dfs1(to[i], rt);
if(!son[rt] || sz[to[i]] > sz[son[rt]])
son[rt] = to[i];
sz[rt] += sz[to[i]];
}
}
void dfs2(int rt, int t) {
top[rt] = t;
if(!son[rt]) return ;
dfs2(son[rt], t);
for(int i = head[rt]; i; i = nex[i]) {
if(to[i] == fa[rt] || to[i] == son[rt]) continue;
dfs2(to[i], to[i]);
}
}
int solve(int x, int y) {
while(top[x] != top[y]) {
if(dep[top[x]] < dep[top[y]]) swap(x, y);
x = fa[top[x]];
}
return dep[x] < dep[y] ? x : y;
}
int main() {
// freopen("in.txt", "r", stdin);
// freopen("out.txt", "w", stdout);
n = read(), m = read();
int rt = read();
int x, y;
for(int i = 1; i < n; i++) {
x = read(), y = read();
add(x, y);
add(y, x);
}
dfs1(rt, 0);
dfs2(rt, rt);
for(int i = 1; i <= m; i++) {
x = read(), y = read();
printf("%d
", solve(x, y));
}
return 0;
}
Tarjan离线求解
思想
本质就是利用了dfs的节点顺序,当我们正在递归两个节点的最近公共祖先时,显然这两个点是属于其子树的节点,那么当我们第一次遍历完两个需要求解的两个点时,其最近的尚未被完全遍历完子节点的节点就是他们两个的最近公共祖先。
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 5e5 + 10;
int head[N], to[N << 1], nex[N << 1], cnt = 1;
int visit[N], fa[N], n, m;
int qhead[N], qto[N << 1], qnex[N << 1], qcnt = 1, qid[N << 1], ans[N];
inline int read() {
int f = 1, x = 0;
char c = getchar();
while(c > '9' || c < '0') {
if(c == '-') f = -1;
c = getchar();
}
while(c >= '0' && c <= '9') {
x = (x << 1) + (x << 3) + (c ^ 48);
c = getchar();
}
return f * x;
}
void add_edge(int x, int y) {
to[cnt] = y;
nex[cnt] = head[x];
head[x] = cnt++;
}
void add_query(int x, int y, int w) {
qto[qcnt] = y;
qnex[qcnt] = qhead[x];
qid[qcnt] = w;
qhead[x] = qcnt++;
}
int find(int rt) {
return rt == fa[rt] ? rt : fa[rt] = find(fa[rt]);
}
void tarjan(int rt, int f) {
for(int i = head[rt]; i; i = nex[i]) {
if(to[i] == f) continue;
tarjan(to[i], rt);
fa[to[i]] = rt;
}
visit[rt] = 1;
for(int i = qhead[rt]; i; i = qnex[i]) {
if(!visit[qto[i]]) continue;
ans[qid[i]] = find(qto[i]);
}
}
int main() {
// freopen("in.txt", "r", stdin);
n = read(), m = read();
int rt = read();
for(int i = 1; i < n; i++) {
int x = read(), y = read();
add_edge(x, y);
add_edge(y, x);
}
for(int i = 1; i <= n; i++) fa[i] = i;
for(int i = 1; i <= m; i++) {
int x = read(), y = read();
add_query(x, y, i);
add_query(y, x, i);
}
tarjan(rt, 0);
for(int i = 1; i <= m; i++)
printf("%d
", ans[i]);
return 0;
}
ST表 + RMQ在线求解
思想
利用dfs的遍历,在遍历两个点的时候,一定会在中间返回到其最近公共祖先,这个时候的公共祖先也就是这两个点的遍历中的最小值。
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
inline ll read() {
ll f = 1, x = 0;
char c = getchar();
while(c > '9' || c < '0') {
if(c == '-') f = -1;
c = getchar();
}
while(c >= '0' && c <= '9') {
x = (x << 1) + (x << 3) + (c ^ 48);
c = getchar();
}
return f * x;
}
const int N = 5e5 + 10;
int head[N], to[N << 1], nex[N << 1], cnt = 1;
int id[N], tot, last;
int st[N << 2][30];
void add(int x, int y) {
to[cnt] = y;
nex[cnt] = head[x];
head[x] = cnt++;
}
void dfs(int rt, int fa) {
id[rt] = last = ++tot;
st[tot][0] = rt;
for(int i = head[rt]; i; i = nex[i]) {
if(to[i] == fa) continue;
dfs(to[i], rt);
st[++tot][0] = rt;
}
}
int MIN(int a, int b) {
return id[a] < id[b] ? a : b;
}
int main() {
// freopen("in.txt", "r", stdin);
int n = read(), m = read(), rt = read();
for(int i = 1; i < n; i++) {
int x = read(), y = read();
add(x, y);
add(y, x);
}
dfs(rt, 0);
int k = log(last) / log(2);
for(int j = 1; j <= k; j++)
for(int i = 1; i + (1 << j) - 1 <= last; i++)
st[i][j] = MIN(st[i][j - 1], st[i + (1 << (j - 1))][j - 1]);
for(int i = 1; i <= m; i++) {
int x = read(), y = read();
x = id[x], y = id[y];
if(x > y) swap(x, y);
int k = log(y - x + 1) / log(2);
printf("%d
", MIN(st[x][k], st[y - (1 << k) + 1][k]));
}
return 0;
}
倍增
思想
类似于快速幂,通过二进制数的组合来达到(log_2)级别的优化,但是需要注意其中进制的枚举大小顺序。
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
inline ll read() {
ll f = 1, x = 0;
char c = getchar();
while(c > '9' || c < '0') {
if(c == '-') f = -1;
c = getchar();
}
while(c >= '0' && c <= '9') {
x = (x << 1) + (x << 3) + (c ^ 48);
c = getchar();
}
return f * x;
}
const int N = 5e5 + 10;
int head[N], to[N << 1], nex[N << 1], cnt = 1;
int fa[N][21], dep[N], n, m;
void add(int x, int y) {
to[cnt] = y;
nex[cnt] = head[x];
head[x] = cnt++;
}
void dfs(int rt, int f) {
dep[rt] = dep[f] + 1;
fa[rt][0] = f;
for(int i = 1; 1 << i <= dep[rt]; i++)//进制由小到大递推
fa[rt][i] = fa[fa[rt][i - 1]][i - 1];
for(int i = head[rt]; i; i = nex[i]) {
if(to[i] == f) continue;
dfs(to[i], rt);
}
}
int LCA(int x, int y) {
if(dep[x] < dep[y]) swap(x, y);
for(int i = 20; i >= 0; i--)//进制由大到小开始组合,
if(dep[fa[x][i]] >= dep[y])
x = fa[x][i];
if(x == y) return x;//注意特判
for(int i = 20; i >= 0; i--)//进制从小到大开始组合,
if(fa[x][i] != fa[y][i])
x = fa[x][i], y = fa[y][i];
return fa[x][0];//这一步尤其考虑,为什么x, y不知LCA,而其父节点就一定是LCA,
}
int main() {
// freopen("in.txt", "r", stdin);
int n = read(), m = read(), rt = read();
for(int i = 1; i < n; i++) {
int x = read(), y = read();
add(x, y);
add(y, x);
}
dfs(rt, 0);
for(int i = 1; i <= m; i++) {
int x = read(), y = read();
printf("%d
", LCA(x, y));
}
return 0;
}