https://vjudge.net/problem/SPOJ-COT
思路:我们对于每个节点都在它父亲上建主席树,因为具有前缀和性质,我们在求树上两点时u->v,它这个区间的值是T[T[y].l].sum+T[T[x].l].sum-T[T[lca].l].sum-T[T[falca].l].sum(画图即可,u点和v点都是一颗从根过来的主席树,lca以上都有重复多余的路径)
类似的题有:https://nanti.jisuanke.com/t/38229 树上路径不大于k的个数 注意的是 题目是边权
https://vjudge.net/problem/HDU-4417 区间小于等于k的个数
#include<bits/stdc++.h> using namespace std; #define ll long long #define lson l,m,rt<<1 #define rson m+1,r,rt<<1|1 const int maxn = 100000+5; struct node{ int l,r,sum; }T[maxn*40]; int root[maxn],a[maxn]; vector<int> v; int head[maxn],Next[maxn<<1],To[maxn<<1],fa[maxn],id[maxn],top[maxn],tp,pos[maxn]; int tot,cnt,cnnt,size[maxn],son[maxn],deep[maxn],n,len; void add(int u,int v){ Next[++cnnt]=head[u]; head[u]=cnnt; To[cnnt]=v; } int getid(int x){ return lower_bound(v.begin(),v.end(),x)-v.begin()+1; } void update(int l,int r,int &x,int y,int L){ T[++cnt]=T[y],T[cnt].sum++,x=cnt; if(l==r) return; int m=(l+r)>>1; if(L<=m) update(l,m,T[x].l,T[y].l,L); else update(m+1,r,T[x].r,T[y].r,L); } int query(int l,int r,int x,int y,int lca,int falca,int pos){ if(l==r) return l; int m=(l+r)>>1; int sum=T[T[y].l].sum+T[T[x].l].sum-T[T[lca].l].sum-T[T[falca].l].sum; if(sum>=pos) return query(l,m,T[x].l,T[y].l,T[lca].l,T[falca].l,pos); else return query(m+1,r,T[x].r,T[y].r,T[lca].r,T[falca].r,pos-sum); } void dfs1(int u,int f,int dep){ fa[u]=f; deep[u]=dep+1; int mx=0; for(int i=head[u]; i!=-1; i=Next[i]){ int v=To[i]; if(v==f) continue; dfs1(v,u,dep+1); size[u]+=size[v]; if(size[v]>mx) mx=size[v],son[u]=v; } size[u]++; } void dfs2(int u,int tp){ id[u]=++tot; pos[tot]=u; top[u]=tp; update(1,len,root[u],root[fa[u]],getid(a[u])); if(son[u]) dfs2(son[u],tp); for(int i=head[u]; i!=-1; i=Next[i]){ int v=To[i]; if(v!=son[u]&&v!=fa[u]) dfs2(v,v); } } int lca(int u,int v){ while(top[u]!=top[v]){ if(deep[top[u]]<deep[top[v]]) swap(u,v); u=fa[top[u]]; } return deep[u]<deep[v]?u:v; } int main(){ int n,m; scanf("%d %d",&n,&m); for(int i=1;i<=n;i++){ scanf("%d",&a[i]); v.push_back(a[i]); } sort(v.begin(),v.end()); v.erase(unique(v.begin(),v.end()),v.end()); len=v.size(); memset(head,-1,sizeof(head)); for(int i=1;i<n;i++){ int u,v; scanf("%d %d",&u,&v); add(u,v); add(v,u); } dfs1(1,0,0); dfs2(1,1); for(int i=1;i<=m;i++){ int l,r,k; scanf("%d %d %d",&l,&r,&k); int Lca=lca(l,r); printf("%d ",v[query(1,len,root[l],root[r],root[Lca],root[fa[Lca]],k)-1]); } return 0; }