算法原理
Treap
一种好用的数据结构,支持插入((insert)),删除((remove)),查前驱((pre))后继((suf)),查树的排名((get rank by val)),据排名查数((getvalbyrank))。
前置知识:BST二叉查找树
二叉树,点带权,左子树上的点都比根小,右子树上的点都比根大。
但是,如果插入的是一个单调的序列,每次对树进行检索操作都是(O(n)) ,总复杂度变成(O(n^2)) ,不能接受。
随机权值&左旋右旋
这是Treap的核心操作。通过对原(BST)树进行旋转操作,使树的高度减小,形状更平衡 。
那么,什么样的操作才是合理的旋转操作呢?由于普通的(BST)在随机的数据下是趋近平衡的,所以我们给每个点一个随机权值,满足大根堆性质,尽量平衡数。
右旋zig
把(p)的左子节点绕着(p)向右旋转。
inline void zig(int &p){
int q=t[p].l; //q为原左儿子
t[p].l=t[q].r;//左儿子的右儿子变成左儿子
t[q].r=p;
p=q;
return;
}
左旋同理。
代码详解
Build
初始状态设为一个(INF) ,一个(-INF) ,防止溢出。
根节点编号设为(1) ,初始权值为(INF)
记得先新建(-INF) ,因为我是初始化根和根的右节点
我就是因为忘了(build)才多调了半小时
inline void build(){
New(-INF),New(INF);
rt=1,t[1].r=2;
update(rt);
return;
}
Get Rank By Value
inline int getrank(int p,int val){
if(!p) return 0;
if(val==t[p].val) return t[t[p].l].sz+1;//找到了
if(val<t[p].val) return getrank(t[p].l,val);//往左子树找
if(val>t[p].val) return getrank(t[p].r,val)+t[p].cnt+t[t[p].l].sz;// 往右子树找
}
Get Value By Rank
inline int getval(int p,int rank){
if(!p) return INF;
if(t[t[p].l].sz>=rank) return getval(t[p].l,rank);// 左子树的大小已经大于等于了,那这个排名的数肯定在左子树
if(t[t[p].l].sz+t[p].cnt>=rank) return t[p].val;
// 左子树的大小+父节点的大小才大于等于这个排名,那就是它了
return getval(t[p].r,rank-t[t[p].l].sz-t[p].cnt);
// 如果还小了,那就是在右子树上找排名为(rank-左子树大小-父节点大小)的数
}
Insert
如果以前没有这个点-->新建
如果以前有这个值了-->(cnt++)
inline void insert(int &p,int val){
if(!p){
p=New(val);
return;
}
if(val==t[p].val){
t[p].cnt++;
update(p);
return;
}
if(val<t[p].val){
insert(t[p].l,val);
if(t[p].data<t[t[p].l].data) zig(p);
}
// 找值,找对地方后进行旋转
if(val>t[p].val){
insert(t[p].r,val);
if(t[p].data<t[t[p].r].data) zag(p);
}
update(p);
return;
}
Remove
inline void remove(int &p,int val){
if(!p) return;
if(val==t[p].val){
if(t[p].cnt>1){
t[p].cnt--,update(p);
return;
//找到了
}
if(t[p].l||t[p].r){
if(!t[p].r||t[t[p].l].data>t[t[p].r].data) zig(p),remove(t[p].r,val);
else zag(p),remove(t[p].l,val);
update(p);
}
else p=0;
return;
}
val<t[p].val?remove(t[p].l,val):remove(t[p].r,val);
//递归查找要删的数
update(p);
return;
}
Pre&Suf
(pre):左子树上一直往右找
(suf):右子树上一直往左找
注意边界条件。
Code
#include<bits/stdc++.h>
#define N (400010)
#define INF (998244353)
using namespace std;
struct xbk{
int l,r,data,val,cnt,sz;
}t[N];
int n,tot,rt;
inline int read(){
int w=0;
bool f=0;
char ch=getchar();
while(ch>'9'||ch<'0'){
if(ch=='-') f=1;
ch=getchar();
}
while(ch>='0'&&ch<='9'){
w=(w<<3)+(w<<1)+(ch^48);
ch=getchar();
}
return f?-w:w;
}
inline int New(int val){
t[++tot].val=val;
t[tot].data=rand();
t[tot].sz=t[tot].cnt=1;
return tot;
}
inline void update(int p){
t[p].sz=t[t[p].l].sz+t[t[p].r].sz+t[p].cnt;
}
inline void build(){
New(-INF),New(INF);
rt=1,t[1].r=2;
update(rt);
return;
}
inline int getrank(int p,int val){
if(!p) return 0;
if(val==t[p].val) return t[t[p].l].sz+1;
if(val<t[p].val) return getrank(t[p].l,val);
if(val>t[p].val) return getrank(t[p].r,val)+t[p].cnt+t[t[p].l].sz;
}
inline int getval(int p,int rank){
if(!p) return INF;
if(t[t[p].l].sz>=rank) return getval(t[p].l,rank);
if(t[t[p].l].sz+t[p].cnt>=rank) return t[p].val;
return getval(t[p].r,rank-t[t[p].l].sz-t[p].cnt);
}
inline void zig(int &p){
int q=t[p].l;
t[p].l=t[q].r,t[q].r=p,p=q;
update(p),update(t[p].r);
}
inline void zag(int &p){
int q=t[p].r;
t[p].r=t[q].l,t[q].l=p,p=q;
update(p),update(t[p].l);
}
inline void insert(int &p,int val){
if(!p){
p=New(val);
return;
}
if(val==t[p].val){
t[p].cnt++;
update(p);
return;
}
if(val<t[p].val){
insert(t[p].l,val);
if(t[p].data<t[t[p].l].data) zig(p);
}
if(val>t[p].val){
insert(t[p].r,val);
if(t[p].data<t[t[p].r].data) zag(p);
}
update(p);
return;
}
inline int getpre(int val){
int ans=1;
int p=rt;
while(p){
if(val==t[p].val){
if(t[p].l){
p=t[p].l;
while(t[p].r) p=t[p].r;
ans=p;
}
break;
}
if(t[p].val<val&&t[p].val>t[ans].val) ans=p;
p=val<t[p].val?t[p].l:t[p].r;
}
return t[ans].val;
}
inline int getsuf(int val){
int ans=2;
int p=rt;
while(p){
if(val==t[p].val){
if(t[p].r){
p=t[p].r;
while(t[p].l) p=t[p].l;
ans=p;
}
break;
}
if(t[p].val>val&&t[p].val<t[ans].val) ans=p;
p=val<t[p].val?t[p].l:t[p].r;
}
return t[ans].val;
}
inline void remove(int &p,int val){
if(!p) return;
if(val==t[p].val){
if(t[p].cnt>1){
t[p].cnt--,update(p);
return;
}
if(t[p].l||t[p].r){
if(!t[p].r||t[t[p].l].data>t[t[p].r].data) zig(p),remove(t[p].r,val);
else zag(p),remove(t[p].l,val);
update(p);
}
else p=0;
return;
}
val<t[p].val?remove(t[p].l,val):remove(t[p].r,val);
update(p);
return;
}
int main(){
build();
n=read();
while(n--){
int opt=read(),val=read();
if(opt==1) insert(rt,val);
if(opt==2) remove(rt,val);
if(opt==3) printf("%d
",getrank(rt,val)-1);
if(opt==4) printf("%d
",getval(rt,val+1));
if(opt==5) printf("%d
",getpre(val));
if(opt==6) printf("%d
",getsuf(val));
}
return 0;
}