我天,这竟然是道最近公共祖先的板子题。代码里利用了我的LCA模板,前向星存图。
#define _CRT_SECURE_NO_WARNINGS #include <cstdio> #include <cstring> #include <algorithm> #include <iostream> using namespace std; const int N = 1e5 + 7, H = 20; int n, tot; //树的根节点 int head_copy[N], head[N]; //存的边的信息,head_copy[i]表示第i个节点的头指针 int to[N << 1]; //第i条边指向的节点 int nxt[N << 1]; //第i条边的下一个指针 int anc[N][H]; //对于每一个节点v,记录anc[v][k],表示它向上走pow(2,k)步之后到达的节点 int Stack[N], dep[N]; void dfs(int root) { int top = 0; dep[root] = 1; for (int i = 0; i < H; ++i) anc[root][i] = root; Stack[++top] = root; //先求出anc[v][0] memcpy(head_copy, head, sizeof head);//head为原始的第i个节点的头指针 while (top) { int x = Stack[top]; if (x != root) { for (int i = 1; i < H; ++i) { //再求出其他anc[v][k] int y = anc[x][i - 1]; anc[x][i] = anc[y][i - 1]; } } for (int &i = head_copy[x]; ~i; i = nxt[i]) { //这里i为引用,会修改head_copy int y = to[i]; if (y != anc[x][0]) { dep[y] = dep[x] + 1; anc[y][0] = x; Stack[++top] = y; } } while (top && head_copy[Stack[top]] == -1) //==-1,这个-1和初始head有关 top--; } } inline void swim(int &x, int k) { //从节点x向上移动k步,并将x赋为新走到的节点 for (int i = 0; k > 0; ++i) { if (k & 1) x = anc[x][i]; k /= 2; } } int lca(int x, int y) { //寻找x, y的LCA。 int k; if (dep[x] > dep[y]) swap(x, y); swim(y, dep[y] - dep[x]); //首先利用swim将x,y调整到同一高度 if (x == y) return x; //若x和y重合,就是我们要找的LCA while (true) { //否则,不断第寻找一个最小的k,使得anc[x][k] = anc[y][k] for (k = 0; anc[x][k] != anc[y][k]; ++k); if (k == 0) return anc[x][0]; x = anc[x][k - 1];//新的x,y和原来的x,y有相同的LCA y = anc[y][k - 1]; } return -1; } void init() { tot = 0; memset(head, -1, sizeof head); } void add(int u, int v) { to[tot] = v; nxt[tot] = head[u]; head[u] = tot++; } int dist(int u, int v) { int x = lca(u, v); int res = dep[u] + dep[v] - 2 * dep[x]; return res; } int main() { ios::sync_with_stdio(false); int q, pi; while (cin >> n >> q) { init(); for (int i = 2; i <= n; ++i) { cin >> pi; add(pi, i); add(i, pi); } dfs(1); int a, b, c, m1, m2, m3, m; while (q--) { cin >> a >> b >> c; m1 = lca(a, b), m2 = lca(a, c), m3 = lca(b, c); if (m1 == m2) m = m3; else if (m1 == m3) m = m2; else m = m1; cout << 1 + max(dist(m, a), max(dist(m, b), dist(m, c))) << endl; } } return 0; }
by myorange