具体讲解还是看发明人陈启峰神犇的吧:http://wenku.baidu.com/link?url=Sh3e8rMJ2Pn146yz0_ClcF_bWTu9uwVEuXy8P0y-CwG-2WNmcDRehaUiuOV-4NcVQBQ9Kpwzd-TwMN3uKigQvzYXm2ZC3UPeoLuKv-Hsapa
核心代码:
版本一:(好理解)
void maintain(int &x,int flag)
{
if(flag) // right
{
// 右孩子的右子树大于左孩子
if(T[T[T[x].ch[1]].ch[1]].sz > T[T[x].ch[0]].sz) rotate(x,0);
// 右孩子的左子树大于左孩子
else if(T[T[T[x].ch[1]].ch[0]].sz > T[T[x].ch[0]].sz) rotate(T[x].ch[1],1),rotate(x,0);
else return;
}
else // left
{
// 左孩子的左子树大于右孩子
if(T[T[T[x].ch[0]].ch[0]].sz > T[T[x].ch[1]].sz) rotate(x,1);
// 右孩子的右子树大于右孩子
else if(T[T[T[x].ch[0]].ch[1]].sz > T[T[x].ch[1]].sz) rotate(T[x].ch[0],0),rotate(x,1);
else return;
}
maintain(T[x].ch[0],false);
maintain(T[x].ch[1],true);
maintain(x,false);
maintain(x,true);
}
版本二:(精简)
#include <cstdio>
#include <cstring>
using namespace std;
const int inf = 1 << 30;
const int maxn = 100000;
int sz[maxn],sn[maxn][2],val[maxn],cnt,root = 0;
int rotate(int &x,int d)
{
int k = sn[x][d ^ 1]; sn[x][d ^ 1] = sn[k][d]; sn[k][d] = x;
sz[k] = sz[x]; sz[x] = sz[sn[x][d]] + 1 + sz[sn[x][d ^ 1]];
x = k;
}
void maintain(int &x,int d)
{
if(sz[sn[sn[x][d]][d]] > sz[sn[x][d ^ 1]]) rotate(x,d ^ 1);
else if(sz[sn[sn[x][d]][d ^ 1]] > sz[sn[x][d ^ 1]]) rotate(sn[x][d],d),rotate(x,d ^ 1);
else return;
maintain(sn[x][0],false), maintain(sn[x][1],true);
maintain(x,false), maintain(x,true);
}
void ins(int &x,int v)
{
if(!x) {x = ++cnt,val[x] = v,sz[x] = 1,sn[x][0] = sn[x][1] = 0; return;}
sz[x] ++; int d = v >= val[x];
ins(sn[x][d],v);
maintain(x,d);
}
int del(int &x,int v) // 若没有v,则删除其后继
{
sz[x] --;
if(v == val[x] || (v < val[x] && !sn[x][0]) || (v > val[x] && !sn[x][1]))
{
int tmp = val[x];
if(!sn[x][0] || !sn[x][1]) x = sn[x][0] + sn[x][1];
else val[x] = del(sn[x][0],inf);
return tmp;
}
else del(sn[x][v >= val[x]],v);
}
int find(int x,int v)
{
while(x && v != val[x])
x = find(sn[x][v >= val[x]],v);
return x;
}
int rank(int x,int v)
{
if(!x) return 1;
if(v <= val[x]) return rank(sn[x][0],v);
else return sz[sn[x][0]] + 1 + rank(sn[x][1],v);
}
int main()
{
int n,m,t,opt;
scanf("%d%d",&n,&m);
for(int i = 0;i < n;i ++) scanf("%d",&t),ins(root,t);
printf("root %d
",root);
for(int i = 1;i <= cnt;i ++)
printf("T[%d] : ls : %d rs : %d val : %d
",i,sn[i][0],sn[i][1],val[i]);
for(int i = 0;i < m;i ++)
{
scanf("%d%d",&opt,&t);
if(opt == 1) printf("del %d
",del(root,t));
else if(opt == 2) printf("find %d
",val[find(root,t)]);
else if(opt == 3) printf("rank %d
",rank(root,t));
}
return 0;
}