// 处理了一年的边界问题 (Author's note: spent a very long time getting the boundary cases right.)
#include <bits/stdc++.h>
using namespace std;

// Program overview (as far as the code itself shows):
//   * reads n nodes of a binary tree rooted at node 1; node i has child ids
//     l and r (0 means "no child") and an integer key k;
//   * builds a heavy-light decomposition so that Lca(u, v) can be answered
//     by jumping along chains;
//   * answers q queries [L, R], printing one integer per query that is
//     derived from the depths of up to two "boundary" nodes and their LCA.
// NOTE(review): the formulas suggest the tree is a binary search tree on k
// and the answer counts steps of a range traversal -- the problem statement
// is not part of this file, confirm before relying on that reading.

const int N = 2e5 + 10;  // capacity of all per-node arrays; node ids are 1-based

#define fi first
#define se second
#define sz(v) ((int)(v).size())
#define debug(a) cout << #a << " = " << a << endl;

typedef long long ll;
typedef pair <int, int> P;  // (key, node id)

const ll MOD = 1e9 + 7;

// The following globals are not referenced by any code in this file.
string arr[N];
map < string, vector < int > > mp;
priority_queue < int, vector < int >, greater < int > > Q;

// One tree node: l / r are the child ids (0 = absent), k is the node's key.
struct node {
    int l, r, k;
} tree[N];

int d[N];   // depth of each node (root has depth 0)
int f[N];   // parent id of each node (0 for the root)
int sz[N];  // number of nodes in each subtree; sz[0] stays 0
int hs[N];  // heavy child of each node (0 for a leaf)

// First decomposition pass: fills d, f, sz and hs for the subtree rooted at u.
// u == 0 denotes a missing child and is ignored.
void dfs1(int u, int fa, int dep) {
    if (!u) return;
    d[u] = dep;
    f[u] = fa;
    sz[u] = 1;
    dfs1(tree[u].l, u, dep + 1);
    dfs1(tree[u].r, u, dep + 1);
    sz[u] += sz[tree[u].l] + sz[tree[u].r];
    if (tree[u].l == 0 && tree[u].r == 0) return;  // leaf: no heavy child
    // On a tie (including "only a right child") the right child is chosen.
    hs[u] = sz[tree[u].l] > sz[tree[u].r] ? tree[u].l : tree[u].r;
}

int top[N];  // topmost node of the heavy chain each node belongs to

// Second decomposition pass: assigns top[] for the subtree of u, where T is
// the top of the chain u belongs to. The heavy child continues the chain;
// the other child (if any) starts a new chain with itself as top.
void dfs2(int u, int T) {
    if (!u) return;
    top[u] = T;
    if (!hs[u]) return;
    dfs2(hs[u], T);
    if (hs[u] != tree[u].l) {
        dfs2(tree[u].l, tree[u].l);
    } else {
        dfs2(tree[u].r, tree[u].r);
    }
}

// Lowest common ancestor of u and v. While the two nodes are on different
// chains, the one whose chain top is deeper jumps to the parent of that top;
// once on the same chain, the shallower node is the answer.
int Lca(int u, int v) {
    while (top[u] != top[v]) {
        if (d[top[u]] < d[top[v]]) swap(u, v);
        u = f[top[u]];
    }
    return d[u] > d[v] ? v : u;
}

vector < P > len;  // (key, node id) for every node, sorted ascending in main

// Assumes 1 <= n < N and that child ids are valid; the input is not validated.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n;
    cin >> n;
    for (int i = 1; i <= n; i++) {
        cin >> tree[i].l >> tree[i].r >> tree[i].k;
        len.emplace_back(tree[i].k, i);
    }
    int q;
    cin >> q;

    sort(len.begin(), len.end());
    dfs1(1, 0, 0);
    // The root is the top of its own chain. The previous sentinel of -1 made
    // Lca() evaluate d[top[u]] == d[-1], i.e. read outside the array.
    dfs2(1, 1);

    while (q--) {
        int L, R;
        cin >> L >> R;

        const int minKey = len.front().fi;
        const int maxKey = len.back().fi;

        // Query range covers every key, or lies entirely below / above all keys.
        if ((L <= minKey && R >= maxKey) || R < minKey || L > maxKey) {
            cout << "1" << '\n';
            continue;
        }

        // Left boundary node: LCA of the first node with key >= L and its
        // predecessor in key order (or just the one that exists).
        int l = 0;
        auto pos1 = lower_bound(len.begin(), len.end(), P(L, 0));
        if (pos1 == len.end()) {
            l = len.back().se;
        } else if (pos1 != len.begin()) {
            l = Lca(pos1->se, prev(pos1)->se);
        } else {
            l = pos1->se;
        }

        // Right boundary node: LCA of the first node with key > R and its
        // predecessor. 10000000 exceeds every node id (ids are < N), so the
        // search lands after all entries whose key equals R.
        int r = 0;
        auto pos2 = upper_bound(len.begin(), len.end(), P(R, 10000000));
        if (pos2 == len.begin()) {
            r = pos2->se;
        } else if (pos2 != len.end()) {
            r = Lca(pos2->se, prev(pos2)->se);
        } else {
            r = len.back().se;
        }

        // '\n' instead of endl: same output bytes, no flush per query.
        if (L <= minKey) {  // range is open on the left: only r matters
            cout << d[r] * 2 + 3 << '\n';
            continue;
        }
        if (R >= maxKey) {  // range is open on the right: only l matters
            cout << d[l] * 2 + 3 << '\n';
            continue;
        }

        const int lca = Lca(l, r);
        int ans = d[lca] * 2 + 1;
        ans += (d[l] - d[lca]) * 2 + (d[r] - d[lca]) * 2 + 2;
        cout << ans << '\n';
    }
    return 0;
}