借这道题练一下Treap和Splay的板子。
code:
#include <cstdio>
#include <cstdlib>

/*
 * Treap version.
 *
 * Input : N, M, then N values a[1..N], then M values b[1..M].
 * Output: the values a[i] are inserted one at a time; every time the number
 *         of inserted values i equals b[j], the j-th smallest value currently
 *         stored is printed (followed by a space) and j advances.
 *
 * Nodes are indices into the global arrays below; index 0 is the empty-tree
 * sentinel and all of its fields stay zero (so f[0] == 0).
 */

/**
 * Reads the next decimal integer (with an optional leading '-') from stdin,
 * skipping every other character that precedes it.
 *
 * @return the parsed value, or 0 if end of input is hit before a number starts.
 */
int read()
{
    // getchar() returns int; keeping it as int lets EOF be told apart from
    // real characters, so a truncated input cannot spin forever here.
    int ch = std::getchar();
    while (ch != EOF && ch != '-' && (ch < '0' || ch > '9'))
        ch = std::getchar();
    if (ch == EOF)
        return 0;

    int sign = 1;
    int value = 0;
    if (ch == '-')
        sign = -1;
    else
        value = ch - '0';

    for (ch = std::getchar(); ch >= '0' && ch <= '9'; ch = std::getchar())
        value = value * 10 + (ch - '0');
    return value * sign;
}

int N, M;          // number of values / number of queries
int a[200005];     // a[1..N]: values in insertion order
int b[200005];     // b[1..M]: the j-th query fires once b[j] values are inserted

int root = 0;      // index of the treap root, 0 while the tree is empty
int tr[200005][2]; // tr[x][0] / tr[x][1]: left / right child of node x
int v[200005];     // value stored in node x
int f[200005];     // number of nodes in the subtree rooted at x
int rd[200005];    // random heap priority of node x
int cnt = 0;       // number of nodes allocated so far

/**
 * Lifts the child of x on side o (0 = left, 1 = right) into x's position.
 *
 * @param x reference to the link that points at the subtree root; it is
 *          overwritten with the new root of the subtree.
 * @param o side of the child that moves up.
 */
void rotate(int &x, int o)
{
    int k = tr[x][o];
    tr[x][o] = tr[k][o ^ 1];
    tr[k][o ^ 1] = x;
    f[k] = f[x];                            // k now spans everything x used to span
    f[x] = f[tr[x][0]] + f[tr[x][1]] + 1;   // recompute x from its new children
    x = k;
}

/**
 * Inserts val into the subtree referenced by x.
 * Values that compare equal to a node go into its left subtree.
 */
void insert(int &x, int val)
{
    if (!x) {
        x = ++cnt;
        v[x] = val;
        f[x] = 1;
        rd[x] = std::rand();
        return;
    }

    f[x]++;
    int to = val > v[x];
    insert(tr[x][to], val);
    // Restore the heap order on priorities: a child may not outrank its parent.
    if (rd[tr[x][to]] > rd[x])
        rotate(x, to);
}

/**
 * Returns the kth smallest value (1-based) in the subtree rooted at x,
 * or -1 when the search runs into an empty subtree.
 */
int Query(int x, int kth)
{
    if (!x)
        return -1;

    int leftSize = f[tr[x][0]];
    if (leftSize >= kth)
        return Query(tr[x][0], kth);
    if (leftSize + 1 < kth)
        return Query(tr[x][1], kth - leftSize - 1);
    return v[x];
}

int main()
{
    std::srand(23333); // fixed seed: runs are reproducible

    N = read();
    M = read();
    for (int i = 1; i <= N; i++)
        a[i] = read();
    for (int i = 1; i <= M; i++)
        b[i] = read();

    int j = 1;
    for (int i = 1; i <= N; i++) {
        insert(root, a[i]);
        // Bound j explicitly instead of relying on b[M + 1] being zero.
        while (j <= M && b[j] == i) {
            std::printf("%d ", Query(root, j));
            j++;
        }
    }
    return 0;
}
#include <cstdio>

/*
 * Splay version.
 *
 * Input : N, M, then N values a[1..N], then M values b[1..M].
 * Output: the values a[i] are inserted one at a time; every time the number
 *         of inserted values i equals b[j], the j-th smallest value currently
 *         stored is printed (followed by a space) and j advances.
 *
 * Nodes are indices into the global arrays below; index 0 is the empty-tree
 * sentinel and all of its fields stay zero (so f[0] == 0).
 */

int root;          // index of the splay root, 0 while the tree is empty
int cnt;           // number of nodes allocated so far
int tr[200005][2]; // tr[x][0] / tr[x][1]: left / right child of node x
int f[200005];     // number of nodes in the subtree rooted at x
int v[200005];     // value stored in node x
int fa[200005];    // parent of node x, 0 for the root

/** Returns 1 if x is the right child of its parent, 0 otherwise. */
int get(int x)
{
    return x == tr[fa[x]][1];
}

/** Recomputes the subtree size of x from its two children. */
void up(int x)
{
    f[x] = f[tr[x][0]] + f[tr[x][1]] + 1;
}

/**
 * Rotates x one level up, above its parent. x must have a parent.
 *
 * @return x, which is now the root of the rotated subtree. (The function was
 *         declared int without returning anything; callers ignore the value.)
 */
int rotate(int x)
{
    int ol = fa[x];      // parent
    int olol = fa[ol];   // grandparent, 0 if the parent is the root
    int to = get(x);     // side of ol on which x hangs

    // x's subtree on the opposite side moves over to the parent.
    int inner = tr[x][to ^ 1];
    tr[ol][to] = inner;
    if (inner)           // leave sentinel node 0 untouched
        fa[inner] = ol;

    tr[x][to ^ 1] = ol;
    fa[ol] = x;
    fa[x] = olol;
    if (olol)
        tr[olol][ol == tr[olol][1]] = x;

    up(ol);              // ol is now below x, so refresh it first
    up(x);
    return x;
}

/** Rotates x all the way to the top of the tree and records it as the root. */
void splay(int x)
{
    for (int p = fa[x]; p != 0; p = fa[x]) {
        // With a grandparent present, pre-rotate: the parent when x and p hang
        // on the same side (zig-zig), otherwise x itself (zig-zag).
        if (fa[p])
            rotate(get(x) == get(p) ? p : x);
        rotate(x);
    }
    root = x;
}

int dist;          // index of the node created by the most recent insert()

/**
 * Inserts val into the subtree referenced by x as a plain BST insertion; the
 * caller is expected to splay `dist` afterwards.
 * Values that compare equal to a node go into its left subtree.
 *
 * @param x   reference to the link that points at the subtree root.
 * @param val value to insert.
 * @param pos parent of the subtree root (0 when x refers to the tree root).
 */
void insert(int &x, int val, int pos)
{
    if (!x) {
        x = ++cnt;
        v[x] = val;
        fa[x] = pos;
        f[x] = 1;
        dist = cnt;
        return;
    }

    int to = val > v[x];
    insert(tr[x][to], val, x);
    up(x);
}

/**
 * Returns the kth smallest value (1-based) in the subtree rooted at x,
 * or -1 when the search runs into an empty subtree.
 * NOTE(review): the node reached is not splayed, so the usual amortised bound
 * of splay trees does not cover these walks.
 */
int Query(int x, int kth)
{
    if (!x)
        return -1;

    int leftSize = f[tr[x][0]];
    if (leftSize >= kth)
        return Query(tr[x][0], kth);
    if (leftSize + 1 < kth)
        return Query(tr[x][1], kth - leftSize - 1);
    return v[x];
}

int N, M;          // number of values / number of queries
int a[200005];     // a[1..N]: values in insertion order
int b[200005];     // b[1..M]: the j-th query fires once b[j] values are inserted

/**
 * Reads the next decimal integer (with an optional leading '-') from stdin,
 * skipping every other character that precedes it.
 *
 * @return the parsed value, or 0 if end of input is hit before a number starts.
 */
int read()
{
    // getchar() returns int; keeping it as int lets EOF be told apart from
    // real characters, so a truncated input cannot spin forever here.
    int ch = std::getchar();
    while (ch != EOF && ch != '-' && (ch < '0' || ch > '9'))
        ch = std::getchar();
    if (ch == EOF)
        return 0;

    int sign = 1;
    int value = 0;
    if (ch == '-')
        sign = -1;
    else
        value = ch - '0';

    for (ch = std::getchar(); ch >= '0' && ch <= '9'; ch = std::getchar())
        value = value * 10 + (ch - '0');
    return value * sign;
}

int main()
{
    N = read();
    M = read();
    for (int i = 1; i <= N; i++)
        a[i] = read();
    for (int i = 1; i <= M; i++)
        b[i] = read();

    int j = 1;
    for (int i = 1; i <= N; i++) {
        insert(root, a[i], 0);
        splay(dist);
        // Bound j explicitly instead of relying on b[M + 1] being zero.
        while (j <= M && b[j] == i) {
            std::printf("%d ", Query(root, j));
            j++;
        }
    }
    return 0;
}