简介
Splay是一种平衡二叉树。它通过不断地将某个节点旋转到根节点,使整棵树仍然满足二叉查找树的性质,并且保持平衡而不至于退化成链。
Splay的时间复杂度是按总复杂度来算的,具体来说,即是:
从空树开始,做插入、删除、访问操作共M次,树中最多同时存在N个点,
则总时间复杂度不超过(O(MlogN))
通常取平均值,表示为单次均摊(O(logN))
复杂度采用了《算法导论》中的摊还分析,可以看这篇博客的证明:https://blog.csdn.net/qq_31640513/article/details/76944892
属性
-
二叉查找树的性质
能够在这棵树上查找某个值的性质:左儿子的值(<)根节点的值(<)右儿子的值
-
节点维护信息
[egin{array}{llllll} hline r t & ext {tot} & f a[i] & operatorname{ch}[i][0 / 1] & v a l[i] & c n t[i] & s z[i] \ hline end{array} ]
int rt;//根节点
int tot;//节点个数
struct node {
int fa;//父亲节点
int ch[2];//子节点
int val;//权值
int cnt;//权值出现次数
int sz;//子树大小
};
这里为了表示每个节点的属性,采用了结构体的形式
方法
基本方法
maintain(x)
:在改变节点位置后,将节点(x)的(size)更新get(x)
:判断该节点是左儿子还是右儿子Clear(x)
:销毁节点(x)
//在改变节点位置后,将节点x的size更新
inline void maintain(int x) {
s[x].sz = s[s[x].ch[0]].sz+s[s[x].ch[1]].sz+s[x].cnt;
}
//判断该节点是左儿子还是右儿子
inline bool get(int x) {return x == s[s[x].fa].ch[1];}
//销毁节点x
inline void Clear(int x) {
s[x].ch[0] = s[x].ch[1] = s[x].fa = s[x].val = s[x].sz = s[x].cnt = 0;
}
旋转方法
必须保证
- 整棵树的中序遍历不变(不能破坏二叉查找树的性质)
- 受影响的节点维护的信息依然正确有效。
- root必须指向旋转后的根节点。
旋转分为两种:左旋和右旋
具体步骤分析:
设要旋转的点是(x),(x)的父亲是(y),(y)的父亲是(z)
分三步:
- (y)与(x)的子节点相连:如果(x)是(y)的左儿子,那么(x)的右儿子与(y)相连
- (x)与(y)父子相连
- (x)与 (y)的原来的父亲 (z)相连:如果(y)是(z)的左儿子,那么(z)的左儿子与(x)相连
Rotate(x)
inline void Rotate(int x) {
int y = s[x].fa, z = s[y].fa, chk = get(x);
//y与x的子节点相连
s[y].ch[chk] = s[x].ch[chk ^ 1];
s[s[x].ch[chk ^ 1]].fa = y;
//x与y父子相连
s[x].ch[chk ^ 1] = y;
s[y].fa = x;
// x与y的原来的父亲z相连
s[x].fa = z;
if(z) s[z].ch[y == s[z].ch[1]] = x;
//只有x和y的sz变化了
maintain(y);
maintain(x);
}
splay方法
每访问一个节点后都要强制将其旋转到根节点
分六种情况:
- 如果(x)的父亲是根节点,直接将(x)左旋或右旋(图(1,2) )
- 如果(x)的父亲不是根节点,且(x)和它父亲的儿子类型(
get(x)==get(f)
)相同,首先将其父亲左旋或右旋,然后将(x)右旋或左旋(图 (3,4)) - 如果(x)的父亲不是根节点,且(x)和父亲的儿子类型不同,将(x)左旋再右旋、或者右旋再左旋(图 (5,6))
splay(x)
:复杂的过程可以转为下面简单的代码
//将当前节点转移到根节点
inline void splay(int x) {
for(int f = s[x].fa; f; Rotate(x),f = s[x].fa){
if(s[f].fa) Rotate(get(x) == get(f) ? f : x);
}
rt=x;
}
因为对于当前(x)Rotate(x)
旋转方式只有一种,如果是右儿子就左旋,左儿子就右旋,减少了很多思维上的麻烦,也就不用纠结该左还是右了。
插入方法
插入方法分三种情况
- 如果树空了则直接插入根并退出
- 如果原来权值存在,权值个数加一
- 树中没有这个值,就新建节点
一定要按照二叉查找树的性质遍历树
找到了return前要进行splay()
操作,来保证树的平衡
ins(k)
//插入操作
inline void ins(int k) {
//如果树空了则直接插入根并退出
if(!rt) {
s[++tot].val = k;
s[tot].cnt++;
rt = tot;
maintain(rt);
return ;
}
int now = rt,f = 0;
while(true) {
//如果原来权值存在,权值个数加一
if(s[now].val == k) {
s[now].cnt++;
maintain(now);
maintain(f);
splay(now);
break;
}
//按照二叉查找树的性质遍历树
f = now;
now = s[now].ch[s[now].val < k];
//树中没有这个值,就新建节点
if(!now) {
s[++tot].val = k;
s[tot].cnt++;
s[tot].fa = f;
s[f].ch[s[f].val < k] = tot;
maintain(tot);
maintain(f);
splay(tot);
break;
}
}
}
查询x的排名
还是按照查找二叉树的性质进行查找
- 如果(x)比当前节点的权值小,向其左子树查找。
- 如果(x)比当前节点的权值大,将答案加上左子树$size (和当前节点)cnt$的大小,向其右子树查找。
- 如果 与当前节点的权值相同,将答案加(1)并返回。
Find(k)
//查找某个数 返回这个数是第几个
inline int Find(int k) {
int res = 0,now = rt;
while(true) {
//如果这个数比当前节点小,搜索左子树
if(k<s[now].val) {
now = s[now].ch[0];
}else {
//否则加上右子树的个数
res += s[s[now].ch[0]].sz;
//中序遍历,如果找到这个节点返回res+1
if(k == s[now].val) {
splay(now);
return res + 1;
}
res += s[now].cnt;
now = s[now].ch[1];
}
}
}
查询排名x的数
- 如果左子树非空且剩余排名(k)不大于左子树的大小 $ size $,那么向左子树查找。
- 否则将(k)减去左子树的和根的大小。如果此时(k)的值小于等于(size),则返回根节点的权值,否则继续向右子树查找。
getKth(k)
//查询第k个数
inline int getKth(int k) {
int now = rt;
while(true){
if(s[now].ch[0] && k <= s[s[now].ch[0]].sz){
now = s[now].ch[0];
}else{
k -= s[now].cnt + s[s[now].ch[0]].sz;
if(k <= 0){
splay(now);
return s[now].val;
}
now=s[now].ch[1];
}
}
}
查询前驱和后继
getPre()
:查询小于x的最大的数的节点,就是找左儿子的右链
getNxt()
:查询大于x的最小的数的节点,就是找右儿子的左链
//查询小于x的最大的数的节点,就是找左儿子的右链
inline int getPre() {
int now = s[rt].ch[0];
while (s[now].ch[1]) now = s[now].ch[1];
return now;
}
//查询大于x的最小的数的节点,同理
inline int getNxt() {
int now = s[rt].ch[1];
while (s[now].ch[0]) now = s[now].ch[0];
return now;
}
删除方法
删除方法具体步骤:
- 首先将(x)旋转到根的位置,要用到
Find(x)
先找到(x) - 如果大于(1),则不需要删除节点,只需要将(cnt-1)
- 如果只有一个点,删除这个点之后,将rt变为(0)
- 如果左右一个儿子,就将该点删除,并让那一个儿子成为根节点
- 否则就将(x)的前驱旋转到根节点,并将(x)的右儿子与根节点相连,将(x)删除。
del(x)
inline void del(int k){
Find(k);//先让该点成为根节点
if(s[rt].cnt > 1) {//如果大于1,不需要删除节点
s[rt].cnt--;
maintain(rt);
return;
}
//如果只有一个点
if(!s[rt].ch[0] && !s[rt].ch[1]){
Clear(rt);
rt = 0;
return;
}
//没有左儿子,让右儿子成为根节点
if(!s[rt].ch[0]){
int tmp = rt;
rt = s[rt].ch[1];
s[rt].fa=0;
Clear(tmp);
return;
}
//没有右儿子,让左儿子成为根节点
if(!s[rt].ch[1]){
int tmp = rt;
rt = s[rt].ch[0];
s[rt].fa = 0;
Clear(tmp);
return;
}
//有左右儿子,让前驱成为根节点
int x = getPre() , now = rt;
splay(x);
s[s[now].ch[1]].fa = x;
s[x].ch[1] = s[now].ch[1];
Clear(now);
maintain(rt);
}
模板题
https://www.luogu.com.cn/problem/P3369
完整代码:
#include<bits/stdc++.h>
using namespace std;
const int N = 1e5+7;
int rt;//根节点
int tot;//节点个数
struct node {
int fa;//父亲节点
int ch[2];//子节点
int val;//权值
int cnt;//权值出现次数
int sz;//子树大小
}s[N];
struct Splay {
//在改变节点位置后,将节点x的size更新
inline void maintain(int x) {
s[x].sz = s[s[x].ch[0]].sz+s[s[x].ch[1]].sz+s[x].cnt;
}
//判断该节点是左儿子还是右儿子
inline bool get(int x) {return x == s[s[x].fa].ch[1];}
//销毁节点x
inline void Clear(int x) {
s[x].ch[0] = s[x].ch[1] = s[x].fa = s[x].val = s[x].sz = s[x].cnt = 0;
}
inline void Rotate(int x) {
int y = s[x].fa, z = s[y].fa, chk = get(x);
//y与x的子节点相连
s[y].ch[chk] = s[x].ch[chk ^ 1];
s[s[x].ch[chk ^ 1]].fa = y;
//x与y父子相连
s[x].ch[chk ^ 1] = y;
s[y].fa = x;
// x与y的原来的父亲z相连
s[x].fa = z;
if(z) s[z].ch[y == s[z].ch[1]] = x;
//只有x和y的sz变化了
maintain(y);
maintain(x);
}
//将当前节点转移到根节点
inline void splay(int x) {
for(int f = s[x].fa; f; Rotate(x),f = s[x].fa){
if(s[f].fa) Rotate(get(x) == get(f) ? f : x);
}
rt=x;
}
//插入操作
inline void ins(int k) {
//如果树空了则直接插入根并退出
if(!rt) {
s[++tot].val = k;
s[tot].cnt++;
rt = tot;
maintain(rt);
return ;
}
int now = rt,f = 0;
while(true) {
//如果原来权值存在,权值个数加一
if(s[now].val == k) {
s[now].cnt++;
maintain(now);
maintain(f);
splay(now);
break;
}
//按照二叉查找树的性质遍历树
f = now;
now = s[now].ch[s[now].val < k];
//树中没有这个值,就新建节点
if(!now) {
s[++tot].val = k;
s[tot].cnt++;
s[tot].fa = f;
s[f].ch[s[f].val < k] = tot;
maintain(tot);
maintain(f);
splay(tot);
break;
}
}
}
//查找某个数 返回这个数是第几个
inline int Find(int k) {
int res = 0,now = rt;
while(true) {
//如果这个数比当前节点小,搜索左子树
if(k<s[now].val) {
now = s[now].ch[0];
}else {
//否则加上右子树的个数
res += s[s[now].ch[0]].sz;
//中序遍历,如果找到这个节点返回res+1
if(k == s[now].val) {
splay(now);
return res + 1;
}
res += s[now].cnt;
now = s[now].ch[1];
}
}
}
//查询小于x的最大的数的节点,就是找左儿子的右链
inline int getPre() {
int now = s[rt].ch[0];
while (s[now].ch[1]) now = s[now].ch[1];
return now;
}
//查询大于x的最小的数的节点,同理
inline int getNxt() {
int now = s[rt].ch[1];
while (s[now].ch[0]) now = s[now].ch[0];
return now;
}
//查询第k个数
inline int getKth(int k) {
int now = rt;
while(true){
if(s[now].ch[0] && k <= s[s[now].ch[0]].sz){
now = s[now].ch[0];
}else{
k -= s[now].cnt + s[s[now].ch[0]].sz;
if(k <= 0){
splay(now);
return s[now].val;
}
now=s[now].ch[1];
}
}
}
//删除结点
inline void del(int k){
Find(k);//先让该点成为根节点
if(s[rt].cnt > 1) {//如果大于1,不需要删除节点
s[rt].cnt--;
maintain(rt);
return;
}
//如果只有一个点
if(!s[rt].ch[0] && !s[rt].ch[1]){
Clear(rt);
rt = 0;
return;
}
//没有左儿子,让右儿子成为根节点
if(!s[rt].ch[0]){
int tmp = rt;
rt = s[rt].ch[1];
s[rt].fa=0;
Clear(tmp);
return;
}
//没有右儿子,让左儿子成为根节点
if(!s[rt].ch[1]){
int tmp = rt;
rt = s[rt].ch[0];
s[rt].fa = 0;
Clear(tmp);
return;
}
//有左右儿子,让前驱成为根节点
int x = getPre() , now = rt;
splay(x);
s[s[now].ch[1]].fa = x;
s[x].ch[1] = s[now].ch[1];
Clear(now);
maintain(rt);
}
}st;
int main(){
int n,opt,x;
scanf("%d",&n);
while(n--){
scanf("%d%d",&opt,&x);
if(opt == 1) st.ins(x);
else if(opt == 2) st.del(x);
else if(opt == 3) printf("%d
",st.Find(x));
else if(opt == 4) printf("%d
",st.getKth(x));
else if(opt == 5) {
st.ins(x);
printf("%d
",s[st.getPre()].val);
st.del(x);
}
else {
st.ins(x);
printf("%d
",s[st.getNxt()].val);
st.del(x);
}
}
return 0;
}
oiwiki上的代码没有用结构体写起来比较快:
#include <cstdio>
const int N = 100005;
int rt, tot, fa[N], ch[N][2], val[N], cnt[N], sz[N];
struct Splay {
void maintain(int x) { sz[x] = sz[ch[x][0]] + sz[ch[x][1]] + cnt[x]; }
bool get(int x) { return x == ch[fa[x]][1]; }
void clear(int x) {
ch[x][0] = ch[x][1] = fa[x] = val[x] = sz[x] = cnt[x] = 0;
}
void rotate(int x) {
int y = fa[x], z = fa[y], chk = get(x);
ch[y][chk] = ch[x][chk ^ 1];
fa[ch[x][chk ^ 1]] = y;
ch[x][chk ^ 1] = y;
fa[y] = x;
fa[x] = z;
if (z) ch[z][y == ch[z][1]] = x;
maintain(x);
maintain(y);
}
void splay(int x) {
for (int f = fa[x]; f = fa[x], f; rotate(x))
if (fa[f]) rotate(get(x) == get(f) ? f : x);
rt = x;
}
void ins(int k) {
if (!rt) {
val[++tot] = k;
cnt[tot]++;
rt = tot;
maintain(rt);
return;
}
int cnr = rt, f = 0;
while (1) {
if (val[cnr] == k) {
cnt[cnr]++;
maintain(cnr);
maintain(f);
splay(cnr);
break;
}
f = cnr;
cnr = ch[cnr][val[cnr] < k];
if (!cnr) {
val[++tot] = k;
cnt[tot]++;
fa[tot] = f;
ch[f][val[f] < k] = tot;
maintain(tot);
maintain(f);
splay(tot);
break;
}
}
}
int rk(int k) {
int res = 0, cnr = rt;
while (1) {
if (k < val[cnr]) {
cnr = ch[cnr][0];
} else {
res += sz[ch[cnr][0]];
if (k == val[cnr]) {
splay(cnr);
return res + 1;
}
res += cnt[cnr];
cnr = ch[cnr][1];
}
}
}
int kth(int k) {
int cnr = rt;
while (1) {
if (ch[cnr][0] && k <= sz[ch[cnr][0]]) {
cnr = ch[cnr][0];
} else {
k -= cnt[cnr] + sz[ch[cnr][0]];
if (k <= 0) return val[cnr];
cnr = ch[cnr][1];
}
}
}
int pre() {
int cnr = ch[rt][0];
while (ch[cnr][1]) cnr = ch[cnr][1];
return cnr;
}
int nxt() {
int cnr = ch[rt][1];
while (ch[cnr][0]) cnr = ch[cnr][0];
return cnr;
}
void del(int k) {
rk(k);
if (cnt[rt] > 1) {
cnt[rt]--;
maintain(rt);
return;
}
if (!ch[rt][0] && !ch[rt][1]) {
clear(rt);
rt = 0;
return;
}
if (!ch[rt][0]) {
int cnr = rt;
rt = ch[rt][1];
fa[rt] = 0;
clear(cnr);
return;
}
if (!ch[rt][1]) {
int cnr = rt;
rt = ch[rt][0];
fa[rt] = 0;
clear(cnr);
return;
}
int x = pre(), cnr = rt;
splay(x);
fa[ch[cnr][1]] = x;
ch[x][1] = ch[cnr][1];
clear(cnr);
maintain(rt);
}
} tree;
int main() {
int n, opt, x;
for (scanf("%d", &n); n; --n) {
scanf("%d%d", &opt, &x);
if (opt == 1)
tree.ins(x);
else if (opt == 2)
tree.del(x);
else if (opt == 3)
printf("%d
", tree.rk(x));
else if (opt == 4)
printf("%d
", tree.kth(x));
else if (opt == 5)
tree.ins(x), printf("%d
", val[tree.pre()]), tree.del(x);
else
tree.ins(x), printf("%d
", val[tree.nxt()]), tree.del(x);
}
return 0;
}