题目描述
您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
1.插入x数
2.删除x数(若有多个相同的数,因只删除一个)
3.查询x数的排名(排名定义为比当前数小的数的个数+1。若有多个相同的数,因输出最小的排名)
4.查询排名为x的数
5.求x的前驱(前驱定义为小于x,且最大的数)
6.求x的后继(后继定义为大于x,且最小的数)
输入输出格式
输入格式:
第一行为n,表示操作的个数,下面n行每行有两个数opt和x,opt表示操作的序号(1≤opt≤6)
输出格式:
对于操作3,4,5,6每行输出一个数,表示对应答案
输入输出样例
输入样例#1:
10 1 106465 4 1 1 317721 1 460929 1 644985 1 84185 1 89851 6 81968 1 492737 5 493598
输出样例#1:
106465 84185 492737
说明
时空限制:1000ms,128M
1.n的数据范围: n≤100000
2.每个数的数据范围:[-107,107]
代码
变量定义
ch[N][2]
:二维数组,ch[x][0]
代表 x的左儿子,ch[x][1]
代表 x的右儿子。
val[N]
:一维数组,val[x]
代表x存储的值。
cnt[N]
:一维数组,cnt[x]
代表x存储的重复权值的个数。
par[N]
:一维数组,par[x]
代表x的父节点。
size[N]
:一维数组,size[x]
代表x子树下的储存的权值数(包括重复权值)。
基本操作
rotate
Splay使用旋转保持平衡。所以旋转是最重要的操作,也是最核心的操作。
splay
将一个节点一路rotate到指定节点的儿子。
find
将最大的小于等于x的数所在的节点splay到根。
#include<bits/stdc++.h> #define inf 0x3f3f3f3f using namespace std; const int maxn=100000+100; int ch[maxn][2],fa[maxn],val[maxn],size[maxn],cnt[maxn]; int root,ncnt; inline int read() { int x=0,f=1;char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} while(ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();} return x*f; } inline int chk(int u) { return ch[fa[u]][1]==u; } inline void pushup(int u) { size[u]=size[ch[u][1]]+size[ch[u][0]]+cnt[u]; } inline void rotate(int u) { int f=fa[u],ff=fa[f],k=chk(u),s=ch[u][k^1]; ch[f][k]=s,fa[s]=f; ch[ff][chk(f)]=u,fa[u]=ff; ch[u][k^1]=f,fa[f]=u; pushup(u),pushup(f); } inline void splay(int u,int goal=0) { while(fa[u]!=goal) { int f=fa[u],ff=fa[f]; if(ff!=goal) { if(chk(u)==chk(f))rotate(f); else rotate(u); } rotate(u); } if(!goal)root=u; } inline void insert(int x) { int u=root,f=0; while(u&&val[u]!=x) f=u,u=ch[u][x>val[u]]; if(u)cnt[u]++; else { u=++ncnt; if(f)ch[f][x>val[f]]=u; fa[u]=f,val[u]=x; size[u]=cnt[u]=1; ch[u][0]=ch[u][1]=0; } splay(u); } inline void find(int x) { int u=root; while(ch[u][x>val[u]]&&val[u]!=x) u=ch[u][x>val[u]]; splay(u); } inline int kth(int k) { int u=root; while(1) { if(ch[u][0]&&k<=size[ch[u][0]]) u=ch[u][0]; else if(k>size[ch[u][0]]+cnt[u]) k-=size[ch[u][0]]+cnt[u],u=ch[u][1]; else return u; } } inline int succ(int x) { find(x); if(val[root]>x)return root; int u=ch[root][1]; while(ch[u][0])u=ch[u][0]; return u; } inline int pre(int x) { find(x); if(val[root]<x)return root; int u=ch[root][0]; while(ch[u][1])u=ch[u][1]; return u; } void remove(int x) { int last=pre(x),next=succ(x); splay(last),splay(next,last); int u=ch[next][0]; if(cnt[u]>1)cnt[u]--,splay(u); else ch[next][0]=0; } int main() { insert(inf),insert(-inf); int n=read(); for(int i=1;i<=n;i++) { int op=read(),x=read(); if(op==1)insert(x); if(op==2)remove(x); if(op==3)find(x),printf("%d ",size[ch[root][0]]); if(op==4)printf("%d ",val[kth(x+1)]); if(op==5)printf("%d ",val[pre(x)]); if(op==6)printf("%d ",val[succ(x)]); } return 0; }