题意:你被给予了N个节点的树。树的节点编号从1到N。每个节点都有权值。会有如下的操作:
u v k:询问从路径u到路径v的第k小的权值
分析:对于可持久化线段树来说,每一棵线段树都维护着序列的前缀。那么转换到树上,就是维护从根节点到当前节点的前缀,因此,对于路径u,v来说,路径u,v上的第k小的权值是root[u] + root[v] -root[lca(u, v)] - root[pa[lca(u, v)]]。root[u] + root[v]让lca(u, v)多加了两次,要减去一次-root[lca(u, v)]。
如下图所示:
root[u]表示从根到u的前缀,root[v]表示从根到v的前缀,root[lca(u, v)]表示从根到lca(u, v)的前缀。
#include <iostream>
#include <cstdio>
#include <cstring>
#include <vector>
#include <algorithm>
using namespace std;
const int N = 100005;
int n, m;
int w[N];
vector<int> nums;
struct ST
{
int lc, rc;
int cnt;
}tr[4 * N + 20 * N];
int fa[N][20];
int dep[N];
int root[N], tot;
int h[N], e[N * 2], ne[N * 2], idx;
void add(int a, int b)
{
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
void pushup(int u)
{
tr[u].cnt = tr[tr[u].lc].cnt + tr[tr[u].rc].cnt;
}
int insert(int p, int l, int r, int x)
{
int q = ++tot;
tr[q] = tr[p];
if (l == r)
{
++tr[q].cnt;
return q;
}
int mid = l + r >> 1;
if (x <= mid) tr[q].lc = insert(tr[p].lc, l, mid, x);
else tr[q].rc = insert(tr[p].rc, mid + 1, r, x);
tr[q].cnt = tr[tr[q].lc].cnt + tr[tr[q].rc].cnt;
return q;
}
//pp:最近公共祖先的父节点
int query(int u, int v, int p, int pp, int l, int r, int k)
{
if (l == r) return l;
int mid = l + r >> 1;
int num = tr[tr[u].lc].cnt + tr[tr[v].lc].cnt - tr[tr[p].lc].cnt - tr[tr[pp].lc].cnt;
if (k <= num) return query(tr[u].lc, tr[v].lc, tr[p].lc, tr[pp].lc, l, mid, k);
else return query(tr[u].rc, tr[v].rc, tr[p].rc, tr[pp].rc, mid + 1, r, k - num);
}
void dfs(int u, int father)
{
fa[u][0] = father;
root[u] = insert(root[fa[u][0]], 0, nums.size() - 1, w[u]);
dep[u] = dep[father] + 1;
for (int i = 1; i <= 19; ++i) fa[u][i] = fa[fa[u][i - 1]][i - 1];
for (int i = h[u]; i != -1; i = ne[i])
{
int j = e[i];
if (j == father) continue;
dfs(j, u);
}
}
int lca(int a, int b)
{
if (dep[a] < dep[b]) swap(a, b);
for (int k = 19; k >= 0; --k)
{
if (dep[fa[a][k]] >= dep[b])
{
a = fa[a][k];
}
}
if (a == b) return a;
for (int k = 19; k >= 0; --k)
{
if (fa[a][k] != fa[b][k])
{
a = fa[a][k];
b = fa[b][k];
}
}
return fa[a][0];
}
int main()
{
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; ++i)
{
scanf("%d", &w[i]);
nums.push_back(w[i]);//离散化值域
}
sort(nums.begin(), nums.end());
nums.erase(unique(nums.begin(), nums.end()), nums.end());
for (int i = 1; i <= n; ++i)
{
w[i] = lower_bound(nums.begin(), nums.end(), w[i]) - nums.begin();//存储离散化的位置
}
int u, v;
memset(h, -1, sizeof h);
for (int i = 1; i < n; ++i)
{
scanf("%d%d", &u, &v);
add(u, v), add(v, u);
}
dfs(1, 0);//插入
int k;
while (m--)
{
scanf("%d%d%d", &u, &v, &k);
int p = lca(u, v);
printf("%d
", nums[query(root[u], root[v], root[p], root[fa[p][0]], 0, nums.size() - 1, k)]);
}
return 0;
}