【传送门:BZOJ2588】
简要题意:
给出n个节点的树,每个点有权值,有m个询问,每个询问输入x,y,k,求出x到y的路径上第k小的值
题解:
树上主席树,将根往下合并主席树
如果要得到x到y的主席树,就将rt[x]+rt[y]-rt[lca]-rt[fa[lca]]就能得到x到y的路径的信息了
参考代码:
#include<cstdio> #include<cstring> #include<algorithm> #include<cstdlib> #include<cmath> using namespace std; struct trnode { int lc,rc,c; }tr[3100000];int trlen,rt[110000]; int s[110000],ls[110000],n; int LS(int d) { int l=1,r=n,ans; while(l<=r) { int mid=(l+r)/2; if(ls[mid]<=d) { ans=mid; l=mid+1; } else r=mid-1; } return ans; } void Link(int &u,int l,int r,int p) { if(u==0) u=++trlen; tr[u].c++; if(l==r) return ; int mid=(l+r)/2; if(p<=mid) Link(tr[u].lc,l,mid,p); else Link(tr[u].rc,mid+1,r,p); } void Merge(int &u1,int u2) { if(u1==0){u1=u2;return ;} if(u2==0) return ; tr[u1].c+=tr[u2].c; Merge(tr[u1].lc,tr[u2].lc); Merge(tr[u1].rc,tr[u2].rc); } int findkth(int u1,int u2,int u3,int u4,int l,int r,int k) { if(l==r) return ls[l]; int c=tr[tr[u1].lc].c+tr[tr[u2].lc].c-tr[tr[u3].lc].c-tr[tr[u4].lc].c; int mid=(l+r)/2; if(c>=k) return findkth(tr[u1].lc,tr[u2].lc,tr[u3].lc,tr[u4].lc,l,mid,k); else return findkth(tr[u1].rc,tr[u2].rc,tr[u3].rc,tr[u4].rc,mid+1,r,k-c); } struct node { int x,y,next; }a[210000];int len,last[110000]; void ins(int x,int y) { len++; a[len].x=x;a[len].y=y; a[len].next=last[x];last[x]=len; } int dep[110000],bin[21],f[110000][21]; void dfs(int x) { for(int i=1;bin[i]<=dep[x];i++) { f[x][i]=f[f[x][i-1]][i-1]; } for(int k=last[x];k;k=a[k].next) { int y=a[k].y; if(y!=f[x][0]) { f[y][0]=x; dep[y]=dep[x]+1; Merge(rt[y],rt[x]); dfs(y); } } } int LCA(int x,int y) { if(dep[x]<dep[y]) swap(x,y); for(int i=20;i>=0;i--) { if(dep[x]-dep[y]>=bin[i]) { x=f[x][i]; } } if(x==y) return x; for(int i=20;i>=0;i--) { if(dep[x]>=bin[i]&&f[x][i]!=f[y][i]) { x=f[x][i];y=f[y][i]; } } return f[x][0]; } int main() { bin[0]=1;for(int i=1;i<=20;i++) bin[i]=bin[i-1]<<1; int m; scanf("%d%d",&n,&m); for(int i=1;i<=n;i++) scanf("%d",&s[i]),ls[i]=s[i]; sort(ls+1,ls+n+1); trlen=0;memset(rt,0,sizeof(rt)); for(int i=1;i<=n;i++) Link(rt[i],1,n,LS(s[i])); len=0;memset(last,0,sizeof(last)); for(int i=1;i<n;i++) { int x,y; scanf("%d%d",&x,&y); ins(x,y); ins(y,x); } dep[1]=1;f[1][0]=0;dfs(1); int lastans=0; for(int i=1;i<=m;i++) { int x,y,k; scanf("%d%d%d",&x,&y,&k); x^=lastans; int lca=LCA(x,y); lastans=findkth(rt[x],rt[y],rt[lca],rt[f[lca][0]],1,n,k); printf("%d ",lastans); } return 0; }