(模板)普通平衡树
题目描述
您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
- 插入(x)数
- 删除(x)数(若有多个相同的数,因只删除一个)
- 查询(x)数的排名(排名定义为比当前数小的数的个数(+1))
- 查询排名为(x)的数
- 求(x)的前驱(前驱定义为小于(x),且最大的数)
- 求(x)的后继(后继定义为大于(x),且最小的数)
输入格式
第一行为(n),表示操作的个数,下面(n)行每行有两个数( ext{opt})和(x),( ext{opt})表示操作的序号((1 leq ext{opt} leq 6))
输出格式
对于操作(3,4,5,6)每行输出一个数,表示对应答案
样例输入
10 1 106465 4 1 1 317721 1 460929 1 644985 1 84185 1 89851 6 81968 1 492737 5 493598
样例输出
106465 84185 492737
说明/提示
【数据范围】
对于(100\%)的数据,(1 le n le 10^5),(|x| le 10^7)
题解
平衡树板题,(Splay)和红黑树都能做(当然我这个菜鸡只会(Splay)了)
首先介绍一下平衡树需要满足的条件:
- 是一个二叉树
- 对于任意一个点,它左子树的所有节点的值都小于这个点的值,它右子树的所有节点的值都大于这个点的值(当然你想左右互换也没人拦你)
那么我们就可以发现这样的性质:
- 这棵二叉树的中序遍历就是这些数从小到大排序
- 我们可以通过旋转来维护这棵树来保证这棵树的中序遍历不变
旋转如下图所示(感性理解一下):
注意,旋转操作是(Splay)的基础
口胡一下就会发现这样并不改变二叉树的中序遍历,个人认为比较好理解,代码实现的时候注意一下不要漏掉需要修改的值就好了,这里就不详细解释了。
然后我们就可以把所有的数放到平衡树里面维护了,接下来我们来考虑平衡树怎么进行各种操作。
Splay
首先我们了解一下(rotate)操作就是如上图所示将一个节点网上旋转。
然后我们考虑到如果不做任何操作的话,平衡树可能会被卡成一条链,所以就有了(Splay)操作来控制时间复杂度。
(Splay)如下图所示:
如果当前节点对于它父亲和它父亲对于它祖父是在同一边(也就是都是左儿子或者都是右儿子)的时候,需要先(rotate)它父亲再(rotate)它自己。
如果当前节点对于它父亲和它父亲对于它祖父再不同边的时候,要(rotate)它自己两次。
为什么这样操作可以优化时间复杂度呢?我也不会证明,能用就行了。
注意:如果你不知道什么地方用(Splay),那么所有操作后都用一次就好了。
(Splay)和(rotate)操作的实现方法千奇百怪,可以写得很长也可以写得很短,所以我建议初学者结合代码理解
插入
由于平衡树的性质,左子树小于当前点,右子树大于当前点,所以我们可以从根开始搜索。
如果插入的值小于当前点,就往左边搜,如果插入的值大于当前点,那么就往右边搜。
当搜到空位时就把要插入的值插到这里。
求(x)的前驱
这个和插入的时候有点类似,也就是根据平衡树的性质,从根节点开始往下搜。
如果当前点小于(x),那么就用这个数更新(ans),然后往右搜。
如果当前点大于等于(x),那么就往左搜。
求(x)的后继
和求前驱一样,反一下就好了。
从根节点开始往下搜,如果当前点大于(x),就用这个数更新(ans),然后往左搜。
如果当前点小于等于(x),那么久往右搜。
删除
这个操作我当时想了很久(主要是因为不同删除方法以及自己想的实现方法很繁琐的困扰),大家可以结合代码去看怎么实现。
首先我们已经回求(x)的前驱和后继了,那么我们考虑一下(x)的前驱和后继中间就只有(x)这一个元素。
那么我们如果把前驱(Splay)到根再把后继(Splay)成根的右孩子(当然也可以反一反,只要把要删除的节点隔绝出来就好了),那么后继的左子树中就只有(x)这一个元素了,然后我们直接删除就好了。
区间翻转
我们先考虑一下如果要求的区间是整棵树,那么如果我们把根的左右子树换一下,那么对于中序遍历的影响就是根这个点左边的区间和右边的区间原封不动的交换一下。
那么我们就可以考虑一下,如果递归下去,把每个节点的左右子树都交换一下,那么就能做到区间翻转的效果了。
但是我们想到操作的区间不一定是整棵树,那么我们就想到删点操作时的思路,把这个区间的前驱和后继分别(Splay)一下,那么这段区间就被隔离成一棵子树了,这样就可以进行区间翻转的操作了。
然后我们想到区间反转的时间复杂度是(O(n))的,所以我们给要翻转的子树的根节点打一个懒标记,当要对这棵子树进行操作的时候,再把标记下传。
我们再考虑到左右子树交换两次等于不交换,也就是懒标记可以抵消。
求(x)的排名
这个操作你需要对于每个节点记录一个值(size),表示以这个节点为根的子树内共有多少节点,每次(Splay)的时候维护一下就好了。
然后和求前驱后继的时候一样的思路,从根节点开始搜索,(x)小于当前节点的时候往左搜,(x)大于当前节点的时候往右搜,并且在(ans)里加上左子树的(size+1)(左子树的根都小于(x),计入排名中),搜到(x)的时候停止。
那么最终答案就是(ans+1)((ans)表示比(x)小的数的个数)。
求第(k)位的数
也是要记录(size)值,从根开始搜。
如果(k)小于等于左子树的(size),那么第(k)位数就在左子树内,
如果(k)等于左子树的(size+1),那么当前节点就是第(k)位数,
如果(k)大于左子树的(size+1),那么第(k)位就在右子树里,把(k)减去((size[u].left+1)),再继续往右子树搜就好了。
总的来说(Splay)只要慢慢琢磨,不是那么难。
上代码:
#include<bits/stdc++.h>
using namespace std;
int n,opt,x;
struct aa{
int x;
int fa,cd[2];
int cnt;
int size;
}p[100009];
int len;
int rt;
void upp(int u){
int fa=p[u].fa;
int ff=p[fa].fa;
bool k=(p[fa].cd[1]==u);
p[ff].cd[p[ff].cd[1]==fa]=u;
p[u].fa=ff;
p[fa].cd[k]=p[u].cd[!k];
p[p[u].cd[!k]].fa=fa;
p[u].cd[!k]=fa;
p[fa].fa=u;
if(!p[u].fa) rt=u;
p[fa].size=p[p[fa].cd[0]].size+p[p[fa].cd[1]].size+p[fa].cnt;
p[u].size=p[p[u].cd[0]].size+p[p[u].cd[1]].size+p[u].cnt;
}
void splay(int u,int to){
while(p[u].fa!=to){
int fa=p[u].fa;
int ff=p[fa].fa;
if(ff==to) {upp(u);break;}
else if((u==p[u].cd[1])==(fa==p[ff].cd[1])) upp(fa);
else upp(u);
upp(u);
}
}
void add(int x){
int up=rt,u=rt;
while(u && p[u].x!=x){
up=u;
p[u].size++;
if(x<p[u].x) u=p[u].cd[0];
else u=p[u].cd[1];
}
if(u && p[u].x==x){p[u].cnt++;splay(u,0);}
else{
p[++len].x=x;
p[len].cnt=p[len].size=1;
p[len].fa=up;
if(x<p[up].x) p[up].cd[0]=len;
else p[up].cd[1]=len;
if(!up) rt=len;
splay(len,0);
}
}
int mn(int x){
int u=rt;
int ans;
while(u){
if(p[u].x<x){
ans=u;
u=p[u].cd[1];
}else u=p[u].cd[0];
}
splay(ans,0);
return ans;
}
int mx(int x){
int u=rt;
int ans;
while(u){
if(p[u].x>x){
ans=u;
u=p[u].cd[0];
}else u=p[u].cd[1];
}
splay(ans,0);
return ans;
}
void del(int x){
int xx=mn(x),yy=mx(x);
splay(xx,0);
splay(yy,rt);
p[p[yy].cd[0]].cnt--;
p[p[yy].cd[0]].size--;
p[yy].size--;
p[xx].size--;
if(p[p[yy].cd[0]].cnt==0) p[yy].cd[0]=0;
}
int ask(int x){
int u=rt;
int ans=0;
int up=rt;
while(u){
up=u;
if(p[u].x<x){
ans+=p[p[u].cd[0]].size+p[u].cnt;
u=p[u].cd[1];
}else u=p[u].cd[0];
}
splay(up,0);
return ans+1;
}
int ask1(int x){
int u=rt;
while(u){
if(x<=p[p[u].cd[0]].size) u=p[u].cd[0];
else if(x>p[p[u].cd[0]].size+p[u].cnt){x=x-p[p[u].cd[0]].size-p[u].cnt;u=p[u].cd[1];}
else {splay(u,0);return p[u].x;}
}
}
int main(){
add(-10000009);
add(10000009);
scanf("%d",&n);
for(int j=1;j<=n;j++){
scanf("%d%d",&opt,&x);
if(opt==1) add(x);
else if(opt==2) del(x);
else if(opt==3) printf("%d
",ask(x)-1);
else if(opt==4) printf("%d
",ask1(x+1));
else if(opt==5) printf("%d
",p[mn(x)].x);
else if(opt==6) printf("%d
",p[mx(x)].x);
}
return 0;
}