简述
二叉搜索树是一种满足对于每个节点,其左儿子权值小于其权值,右儿子大于其权值的树形结构。它在随机情况下能表现较优秀,但如果数据使得树形结构变得非常不平衡,其复杂度就会大幅度退化。这时候为了保证其平衡性质,各种平衡树应运而生。替罪羊树是其中一种优雅的暴力,它的中心思想是当某一个节点不够平衡的时候立刻重构。判断的标准是设定一个属于((0.5, 1))的参数(alpha),若某个儿子的大小大于父亲大小乘以(alpha),则说明不够平衡需要重构。
大致变量名:
const double alpha=0.7;
const int INF=0x7fffffff;//用于防止边界问题
int a[maxn];//一个用于存中序遍历的数组
struct ScapegoatTree
{
int l,r;//左右儿子
int val,size,sum,cnt;//当前权值,子树大小,子树中真实存在的数据个数(可重集元素个数),当前节点的相同数据个数(注意区分size和sum)
ScapegoatTree() {l=r=0;}
#define l(i) (t[i].l)
#define r(i) (t[i].r)
#define val(i) (t[i].val)
#define size(i) (t[i].size)
#define sum(i) (t[i].sum)
#define cnt(i) (t[i].cnt)
}t[maxn];
初始化
一般为了避免讨论一些边界问题,平衡树都会初始化插入一个负无穷和一个正无穷。
void update(int p)
{
size(p)=size(l(p))+size(r(p))+1;
sum(p)=sum(l(p))+sum(r(p))+cnt(p);
}
int New(int val)
{
val(++tot)=val;
size(tot)=sum(tot)=cnt(tot)=1;
return tot;
}
void init()
{
size(0)=sum(0)=cnt(0)=0;
root=New(-INF);
r(root)=New(INF);
update(root);
}
重构
一般来说,(alpha)取0.7或0.8。(这是一个稍微有点玄学的东西,差不多就行了)
bool check(int p) {return !cnt(p)||alpha*size(p)<(double)max(size(l(p)), size(r(p)));}
重构分为以下两个部分:将原子树按中序遍历展开为数组,然后每次选中间的数作为根递归建树并自底向上地更新信息。按照中序遍历是由于二叉搜索树的性质保证了其中序遍历为根据权值有序的数组,重构时也将满足该性质。
void flatten(int p,int &tot)
{
if (!p)
return ;
flatten(l(p), tot);
if (cnt(p))
a[++tot]=p;
flatten(r(p), tot);
}
void build(int &p,int l,int r)
{
if (l>r)
{
p=0;
return ;
}
int mid=(l+r)>>1;
p=a[mid];
build(l(p), l, mid-1),build(r(p), mid+1, r);
update(p);
}
void rebuild(int &p)
{
int tot=0;
flatten(p, tot);
build(p, 1, tot);
}
插入
和其它平衡树差不多,都是空了直接新建节点存,比当前小往左跑,比当前大往右跑,相等直接增加个数。
void insert(int &p,int val)
{
if (!p)
{
p=New(val);
return ;
}
if (val==val(p))
cnt(p)++;
else if (val<val(p))
insert(l(p), val);
else
insert(r(p), val);
update(p);
if (check(p))
rebuild(p);
}
删除
与插入同理。其实并不需要使用惰性删除,用摊还分析或势能分析就可以发现当节点为空时直接将其子树重构不影响其复杂度。
void remove(int &p,int val)
{
if (!p)
return ;
if (val==val(p))
{
if (cnt(p))
cnt(p)--;
}
else if (val<val(p))
remove(l(p), val);
else
remove(r(p), val);
upadte(p);
if (check(p))
rebuild(p);
}
查询前驱/后继
一路往查询的值靠近顺便更新答案,若找到相等的节点则答案在其左子树的最右节点(前驱)或右子树的最左节点(后继)
int getPre(int val)
{
int ans=1,p=root;//val(1)=-INF
while(p)
{
if (val==val(p))
{
if (l(p))
{
p=l(p);
while(r(p))
p=r(p);
ans=p;
}
break;
}
if (val(p)<val&&val(p)>val(ans))
ans=p;
p=val<val(p)?l(p):r(p);
}
return val(ans);
}
int getNxt(int val)
{
int ans=2,p=root;//val(2)=INF
while(p)
{
if (val==val(p))
{
if (r(p))
{
p=r(p);
while(l(p))
p=l(p);
ans=p;
}
break;
}
if (val(p)>val&&val(p)<val(ans))
ans=p;
p=val<val(p)?l(p):r(p);
}
return val(ans);
}
通过排名查询权值
直接看代码吧。
int getRank(int p,int val)
{
if (val==val(p))
return sum(l(p))+1;
return val<val(p)?getRank(l(p), val):getRank(r(p), val)+sum(l(p))+cnt(p);
}
查询某权值的排名
还是看代码,你可以的。
int getVal(int p,int rank)
{
if (sum(l(p))+1<=rank&&rank<=sum(l(p))+cnt(p))
return val(p);
return rank<=sum(l(p))?getVal(l(p), rank):getVal(r(p), rank-sum(l(p))-cnt(p));
}
模板
参考代码:[click]
#include <cstdio>
#include <cctype>
const int INF=0x7fffffff;
const int maxn=2e6+10;
const double alpha=0.7;
int a[maxn];
int root,tot;
struct ScapegoatTree
{
int l,r;
int val,size,sum,cnt;
ScapegoatTree() {l=r=0;}
#define l(i) (t[i].l)
#define r(i) (t[i].r)
#define val(i) (t[i].val)
#define size(i) (t[i].size)
#define sum(i) (t[i].sum)
#define cnt(i) (t[i].cnt)
}t[maxn];
int max(int x,int y) {return x>y?x:y;}
bool check(int p) {return !cnt(p)||alpha*size(p)<(double)max(size(l(p)), size(r(p)));}
int read()
{
int res=0;
char ch=getchar();
while(!isdigit(ch))
ch=getchar();
while(isdigit(ch))
res=res*10+ch-'0',ch=getchar();
return res;
}
void update(int p)
{
size(p)=size(l(p))+size(r(p))+1;
sum(p)=sum(l(p))+sum(r(p))+cnt(p);
}
int New(int val)
{
val(++tot)=val;
size(tot)=sum(tot)=cnt(tot)=1;
return tot;
}
void init()
{
size(0)=sum(0)=cnt(0)=0;
root=New(-INF);
r(root)=New(INF);
update(root);
}
void flatten(int p,int &tot)
{
if (!p)
return ;
flatten(l(p), tot);
if (cnt(p))
a[++tot]=p;
flatten(r(p), tot);
}
void build(int &p,int l,int r)
{
if (l>r)
{
p=0;
return ;
}
int mid=(l+r)>>1;
p=a[mid];
build(l(p), l, mid-1),build(r(p), mid+1, r);
update(p);
}
void rebuild(int &p)
{
int tot=0;
flatten(p, tot);
build(p, 1, tot);
}
void insert(int &p,int val)
{
if (!p)
{
p=New(val);
return ;
}
if (val==val(p))
cnt(p)++;
else if (val<val(p))
insert(l(p), val);
else
insert(r(p), val);
update(p);
if (check(p))
rebuild(p);
}
void remove(int &p,int val)
{
if (!p)
return ;
if (val==val(p))
cnt(p)--;
else
val<val(p)?remove(l(p), val):remove(r(p), val);
update(p);
if (check(p))
rebuild(p);
}
int getPre(int val)
{
int ans=1,p=root;
while(p)
{
if (val==val(p))
{
if (l(p))
{
p=l(p);
while(r(p))
p=r(p);
ans=p;
}
break;
}
if (val(p)<val&&val(p)>val(ans))
ans=p;
p=val<val(p)?l(p):r(p);
}
return val(ans);
}
int getNxt(int val)
{
int ans=2,p=root;
while(p)
{
if (val==val(p))
{
if (r(p))
{
p=r(p);
while(l(p))
p=l(p);
ans=p;
}
break;
}
if (val(p)>val&&val(p)<val(ans))
ans=p;
p=val<val(p)?l(p):r(p);
}
return val(ans);
}
int getRank(int p,int val)
{
if (val==val(p))
return sum(l(p))+1;
return val<val(p)?getRank(l(p), val):getRank(r(p), val)+sum(l(p))+cnt(p);
}
int getVal(int p,int rank)
{
if (sum(l(p))+1<=rank&&rank<=sum(l(p))+cnt(p))
return val(p);
return rank<=sum(l(p))?getVal(l(p), rank):getVal(r(p), rank-sum(l(p))-cnt(p));
}
int main()
{
int n=read(),m=read();
init();
for (int i=1;i<=n;i++)
insert(root, read());
int ans=0,last=0;
for (int i=1;i<=m;i++)
{
int op=read(),x=read()^last;
if (op==1)
insert(root, x);
else if (op==2)
remove(root, x);
else if (op==3)
{
insert(root, x);
ans^=(last=getRank(root, x)-1);
remove(root, x);
}
else if (op==4)
ans^=(last=getVal(root, x+1));
else if (op==5)
ans^=(last=getPre(x));
else
ans^=(last=getNxt(x));
}
printf("%d
",ans);
return 0;
}