zoukankan      html  css  js  c++  java
  • 平衡树Splay

    平衡树

    Splay

    又叫伸展树
    一种比较玄学的平衡树,实现原理基于更加玄学的 splay(伸展)函数。
    理论复杂度为均摊 $O(n\log n)$,但实现中常数因子较大。
    在随机数据下表现最差,在众多平衡树中时间效率垫底(跟我一样,菜得离谱)。
    代码实现较为复杂,但它的优势恰恰来自它的劣势:
    正是使常数因子变大的 splay 函数,造就了 Splay 强大的区间处理能力,让它在一亿种平衡树中占有一席之地。

    Code

    P3369
    优美的代码……(别人的)

    #include<cstdio>
    #include<cstdlib>
    #include<cmath>
    #include<ctime>
    #include<cstring>
    #include<iostream>
    #include<algorithm>
    #include<stack>
    #include<queue>
    #include<vector>
    #include<bitset>
    #include<set>
    #include<map>
    // Generic contest-template shorthands. Of these, only INF (used for the two
    // sentinel keys inserted in main) appears to be referenced by this program.
    // Caution: ls, rs, us, il, rg, pub, lth and eps are object-like macros, so
    // those words cannot be used as ordinary identifiers further down the file.
    #define LL long long
    #define rg register
    #define il inline
    #define us unsigned
    #define eps 1e-6
    #define INF 0x3f3f3f3f
    #define ls k<<1
    #define rs k<<1|1
    #define tmid ((tr[k].l+tr[k].r)>>1)
    #define nmid ((l+r)>>1)
    #define Thispoint tr[k].l==tr[k].r
    #define pub push_back
    #define lth length
    #define pii pair<int,int>
    #define mkp make_pair
    using namespace std;
    // Read a decimal int from stdin into x, skipping leading non-digit
    // characters. A '-' anywhere in the skipped prefix makes the result
    // negative. If input ends before any digit, x is left as 0 instead of
    // spinning forever on EOF.
    inline void Read(int &x){
    	int f=1;
    	int ch=getchar();   // int, not char: it must be able to hold EOF
    	x=0;
    	while(ch!=EOF&&(ch<'0'||ch>'9')){
    		if(ch=='-')f=-1;
    		ch=getchar();
    	}
    	while(ch>='0'&&ch<='9'){
    		x=x*10+(ch-'0');
    		ch=getchar();
    	}
    	x*=f;
    }
    
    // Read a decimal integer of any integral type T from stdin into x.
    // Non-digit characters before the number are skipped; only a '-' that
    // directly precedes the digits makes the value negative. If input ends
    // before a digit, x becomes 0 (the old char-based loop never saw EOF and hung).
    template<typename T>
    inline void read(T &x)
    {
    	x=0;
    	int f=1;
    	int s=getchar();   // int, not char, so EOF is distinguishable
    	while(s!=EOF&&(s<'0'||'9'<s)){f=(s=='-')?-1:1;s=getchar();}
    	while('0'<=s&&s<='9'){x=x*10+(s-'0');s=getchar();}
    	x*=f;
    }
     
    // Capacity of the static node pool. Index 0 acts as the null node
    // (0 means "no child" / "no parent" throughout the code below).
    const int p=5e5+5;
    
    // Field accessors for node u.
    #define lson(u) t[u].ch[0]
    #define rson(u) t[u].ch[1]
    #define f(u) t[u].ff
    #define v(u) t[u].val
    #define c(u) t[u].cnt
    #define s(u) t[u].son
    
    // root: index of the tree root (0 = empty); tot: number of nodes allocated so far.
    int root  = 0,tot = 0;
    
    struct node{
    	int ch[2];  // ch[0] = left child, ch[1] = right child (0 = none)
    	int ff;     // parent index (0 = none)
    	int  val;   // key stored in this node
    	int cnt;    // multiplicity of val
    	int son;    // total multiplicity of the subtree, maintained by pushup
    	
    }t[p]; 
    
    inline void pushup(int u)// 合并函数
    {
    	s(u)= s(lson(u))+s(rson(u))+c(u);
    }
    
    // Single rotation lifting x above its parent y; works for x being either
    // the left or the right child. Sizes of y and then x are refreshed.
    inline void rotate(int x)
    {
    	const int y = t[x].ff;
    	const int z = t[y].ff;
    	const int side = (t[y].ch[1] == x);      // 1 when x is y's right child
    	t[z].ch[t[z].ch[1] == y] = x;            // x takes over y's slot under z
    	t[x].ff = z;
    	const int inner = t[x].ch[side ^ 1];     // x's subtree on the side facing y
    	t[y].ch[side] = inner;
    	t[inner].ff = y;                         // also writes t[0].ff when inner == 0
    	t[x].ch[side ^ 1] = y;
    	t[y].ff = x;
    	pushup(y);
    	pushup(x);
    }
    
    inline void splay(int x,int goal)//高效的伸展操作,对了goal不是你要转到的位置,而是你的目的地的父亲
    {
    	while(f(x)!=goal)
    	{
    		int y = f(x);
    		int z = f(y);
    		if(z!=goal)
    		(lson(y) == x)^(lson(z)==y)?rotate(x):rotate(y);
    		rotate(x);
    	}
    	if(goal == 0)
    	root = x;
     } 
     
    inline void insert(int x)
    {
    	int u = root,ff = 0;
    	while(u&&t[u].val!=x)
    	{
    		ff = u;
    		u = t[u].ch[x > t[u].val];
    	}
    	if(u) c(u)++;
    	else
    	{
    		u =++tot;
    		if(ff) t[ff].ch[x>t[ff].val]  =u;//注意这里的特判是为了插入INF与-INF,防止0节点有儿子,否则可以预料会WA的莫名其妙
    		lson(tot) =rson(tot) = 0;
    		f(tot) = ff;v(tot)= x;
    		c(tot)=1;s(tot)=1;
    	}
    	splay(u,0);
    }
    
    // Splay the node holding x to the root. If x is absent, the last node
    // reached on the search path is splayed instead. No-op on an empty tree.
    void Find(int x)
    {
    	if (root == 0) return;
    	int cur = root;
    	for (;;)
    	{
    		if (x == t[cur].val) break;
    		const int step = t[cur].ch[x > t[cur].val];
    		if (step == 0) break;
    		cur = step;
    	}
    	splay(cur, 0);
    }
    
    // Return the node index of x's strict predecessor (f == 0) or strict
    // successor (f == 1). Side effect: Find(x) restructures the tree.
    int prenxt(int x,int f)
    {
    	Find(x);
    	const int top = root;
    	const bool rootQualifies = f ? (t[top].val > x) : (t[top].val < x);
    	if (rootQualifies) return top;
    	// Otherwise step once toward side f, then run to the far end of side f^1.
    	int cur = t[top].ch[f];
    	while (t[cur].ch[f ^ 1]) cur = t[cur].ch[f ^ 1];
    	return cur;
    }
    
    void Delete(int x)
    {
    	int last = prenxt(x,0);
    	int nxt = prenxt(x,1);
    	splay(last,0);
    	splay(nxt,last);
    	int del = lson(nxt);
    	if(c(del) > 1)
    	{
    		c(del)--;
    		splay(del,0);
    	}
    	else lson(nxt) = 0;
    }
    
    int k_th(int x)
    {
    	int u =root;
    	if(s(u)<x) return 0;
    	while(23333)
    	{
    		int y = lson(u);
    		if(x>c(u)+s(y))
    		{
    			x-=c(u)+s(y);
    			u = rson(u);
    		}
    		else
    			if(s(y)>=x) u = y;
    			else return v(u);
    	}
    }
    
    signed main()
    {
    	int n;
    	
    	insert(-INF);
    	insert(INF);
    	read(n);
    	
    	for(int i=1;i<=n;i++)
    	{
    		int opt;
    		read(opt);
    		int x;
    		read(x);
    		if(opt == 1) {insert(x);}
    		if(opt == 2){Delete(x);}
    		if(opt==3){	Find(x);printf("%d
    ",s(lson(root)));}
    		if(opt == 4){printf("%d
    ",k_th(x+1));}
    		if(opt == 5){printf("%d
    ",v(prenxt(x,0)));}
    		if(opt == 6){printf("%d
    ",v(prenxt(x,1)));}
    		
    	}
    }
    

    丑陋的指针版代码(不过这个好像比上面的快(48ms))

    #include<iostream>
    #include<cstdio>
    #include<cstring>
    
    using namespace std;
    
    // From here on every `int` is really `long long` (hence `signed main` and the
    // %lld printf specifiers below).
    #define int long long
    // Sentinel magnitude. Parenthesized so that expressions such as -INF or INF/2
    // bind as intended (unparenthesized, -INF left-shifts a negative number).
    #define INF (1<<30)
    
    // Read a decimal integer of any integral type T from stdin into x.
    // Non-digit characters before the number are skipped; only a '-' that
    // directly precedes the digits makes the value negative. If input ends
    // before a digit, x becomes 0 (the old char-based loop never saw EOF and hung).
    template<typename T>
    inline void read(T &x)
    {
    	x=0;
    	int f=1;
    	int s=getchar();   // int, not char, so EOF is distinguishable
    	while(s!=EOF&&(s<'0'||s>'9')) {f=(s=='-')?-1:1;s=getchar();}
    	while('0'<=s&&s<='9'){x=x*10+(s-'0');s=getchar();}
    	x*=f;
    }
    
    // Node of the pointer-based splay tree.
    struct node{
    	int key;   // stored value
    	int sum;   // total multiplicity of this node's subtree (kept by hb)
    	int cnt;   // number of copies of key held here (0 for the sentinels set up in main)
    	node *ls,*rs,*fa;	// left child, right child, parent (NULL when absent)
    }; 
    
    // Fixed-size node arena; pool points at the next unused slot.
    // aux and rot are the scratch and root pointers used by main and the tree code.
    node mem[233333];
    node *pool = mem;
    node *aux;
    node *rot;
    // Hand out the next unused arena slot (no capacity check).
    node *New()
    {
    	node *fresh = pool;
    	++pool;
    	return fresh;
    }
    
    // Refresh a->sum: a's own multiplicity plus the sums of whichever children exist.
    void hb(node *a)
    {
    	int total = a->cnt;
    	if (a->ls != NULL) total += a->ls->sum;
    	if (a->rs != NULL) total += a->rs->sum;
    	a->sum = total;
    }
    
    // Right rotation: x (left child of p) moves up, p becomes x's right child
    // and adopts x's former right subtree. Sums of p and x are refreshed and
    // the grandparent's link (if any) is redirected to x.
    void zig(node *x)
    {
    	node *p = x->fa;
    	node *g = p->fa;
    	node *mid = x->rs;
    	x->fa = g;
    	p->ls = mid;
    	if (mid) mid->fa = p;
    	p->fa = x;
    	x->rs = p;
    	hb(p);
    	hb(x);
    	if (g == NULL) return;
    	if (g->ls == p) g->ls = x;
    	else g->rs = x;
    }
    
    // Left rotation: x (right child of p) moves up, p becomes x's left child
    // and adopts x's former left subtree. Sums of p and x are refreshed and
    // the grandparent's link (if any) is redirected to x.
    void zag(node *x)
    {
    	node *p = x->fa;
    	node *g = p->fa;
    	node *mid = x->ls;
    	x->fa = g;
    	p->rs = mid;
    	if (mid) mid->fa = p;
    	p->fa = x;
    	x->ls = p;
    	hb(p);
    	hb(x);
    	if (g == NULL) return;
    	if (g->ls == p) g->ls = x;
    	else g->rs = x;
    }
    // Rotate x upward until it occupies S's former position (S must be an
    // ancestor of x). Callers update rot themselves when S was the root.
    inline void splay(node *x,node *S)
    {
    	if (x == S) return;
    	for (;;)
    	{
    		node *y = x->fa;
    		node *z = y->fa;
    		const bool xLeft = (y->ls == x);
    		if (y == S)
    		{
    			// Single rotation finishes the job.
    			xLeft ? zig(x) : zag(x);
    			return;
    		}
    		const bool yLeft = (z->ls == y);
    		if (xLeft == yLeft)
    		{
    			// zig-zig / zag-zag: rotate the parent first.
    			if (xLeft) { zig(y); zig(x); }
    			else { zag(y); zag(x); }
    		}
    		else
    		{
    			// zig-zag / zag-zig: rotate x twice.
    			if (xLeft) { zig(x); zag(x); }
    			else { zag(x); zig(x); }
    		}
    		if (z == S) return;
    	}
    }
    
    // Allocate a fresh leaf holding val with multiplicity 1 and return it.
    // Every caller uses the returned pointer, so the return statement is
    // essential (flowing off the end of a non-void function is undefined behavior).
    node *make(int val)
    {
    	node *s = New();
    	s->key = val;
    	s->sum = 1;
    	s->cnt = 1;
    	s->rs = s->ls = NULL;
    	s->fa = NULL;   // callers that attach the node set its parent themselves
    	return s;
    }
    // Debug helper: when z holds key 451009, print that key followed by the
    // address of the local parameter z (not the node's address).
    inline void dh(node *z)
    {
    	if (z->key != 451009) return;
    	cout<<(z->key)<<" ";
    	cout<<(&z);
    }
    // Index of the operation currently being processed (set by main; debugging aid).
    int pity;
    // Insert one copy of x into the subtree rooted at op without splaying.
    // Returns the node that now holds x (existing or newly created); the sums
    // of every node on the path from that node up to op are refreshed.
    inline node *insert(int x,node *op)
    {
    	if (op->key == x)
    	{
    		++op->cnt;
    		hb(op);
    		return op;
    	}
    	node *&slot = (op->key < x) ? op->rs : op->ls;
    	node *placed;
    	if (slot == NULL)
    	{
    		slot = make(x);
    		slot->fa = op;
    		placed = slot;
    	}
    	else
    	{
    		placed = insert(x, slot);
    	}
    	hb(op);
    	return placed;
    }
    
    // Return the node whose key equals x, searching down from op.
    // Precondition: x is present (a missing key walks off a NULL child).
    inline node *Find(int x,node *op)
    {
    	while (op->key != x)
    		op = (op->key < x) ? op->rs : op->ls;
    	return op;
    }
    
    inline int qRank(int x)
    {
    	node *op = Find(x,rot);
    	splay(op,rot);
    	rot = op;
    	if(op->ls) return op->ls->sum+1;
    	else return 1; 
    }
    
    // Return the key of the x-th smallest element (by multiplicity) in the
    // subtree rooted at sp.
    inline int rankq(int x,node *sp)
    {
    	for (;;)
    	{
    		const int leftSum = sp->ls ? sp->ls->sum : 0;
    		if (x > leftSum + sp->cnt)
    		{
    			x -= leftSum + sp->cnt;
    			sp = sp->rs;
    		}
    		else if (x <= leftSum)
    		{
    			sp = sp->ls;
    		}
    		else
    		{
    			return sp->key;
    		}
    	}
    }
    
    // Rightmost (largest-key) node of s's subtree.
    inline node *mAx(node *s){ while (s->rs) s = s->rs; return s; }
    // Leftmost (smallest-key) node of s's subtree.
    inline node *mIn(node *s){ while (s->ls) s = s->ls; return s; }
    
    inline void Delete(int x)
    {
    	node *op = Find(x,rot);
    	splay(op,rot);
    	rot = op;
    	if(rot->cnt > 1) {rot->cnt--;hb(rot);return;}
    	if(!rot->ls){rot = rot->rs;rot->fa = 0;return;}
    	node *s = mAx(rot->ls);
    	splay(s,rot->ls);
    	s->rs = rot->rs;
    	if(s->rs)s->rs->fa = s;
    	s->fa=0;
    	hb(s);
    	rot = s;
    }
    
    inline int pre(int x)
    {
    	node op = *rot;
    	int ans = -INF;
    	while(23333)
    	{
    		if(op.key < x) 
    		{
    			ans = max(op.key , ans);
    			if(op.rs!=NULL) op = *op.rs;
    			else break;
    		}
    		else 
    		{
    			if(op.ls!=NULL)op = *op.ls;
    			else break;
    		}
    	}
    	return ans;	
    }
    
    inline int nxt(int x)
    {
    	node op = *rot;
    	int ans = INF;
    	while(2333)
    	{
    		if(op.key > x) 
    		{
    			ans = min(op.key , ans);
    			if(op.ls!=NULL) op = *op.ls;
    			else break;
    		}
    		else 
    		{
    			if(op.rs!=NULL)op = *op.rs;
    			else break;
    		}
    	}
    	return ans;	
    }
    
    signed main()
    {
    	rot = make(INF);
    	rot->sum = rot->cnt = 0;
    	insert(-INF,rot);
    	rot->ls->sum = rot->ls->cnt = 0;
    	hb(rot);
    	int n;
    	
    	read(n);
    	
    	for(int i=1,op,val;i<=n;i++)
    	{
    		pity = i;
    		read(op);
    		read(val);
    		if(op == 1) 
    		{
    			aux = insert(val,rot);
    			splay(aux,rot);
    			rot = aux;	
    		}
    		if(op == 2)	Delete(val);
    		if(op == 3) printf("%lld
    ",qRank(val));
    		if(op == 4) printf("%lld
    ",rankq(val,rot));//cout<<(qrank(val))<<'
    ';
    		if(op == 5) printf("%lld
    ",pre(val));
    		if(op == 6) printf("%lld
    ",nxt(val));
    		
    		
    	}	
    

    文艺平衡树

    要求实现区间翻转操作。
    这个题我一开始一直想不明白怎么用 splay 维护,
    直到看了这篇博客,
    才突然意识到自己思维僵化严重……
    根本想不到跳出值域平衡树的圈子。

    总之就是在 splay 的时候不再在意每个点的权值(节点按中序位置而不是按值组织),这样才能从一棵值域平衡树进化为区间树。

    细节(detail)

    维护序列前先插入哨兵节点 $1$ 和 $n+2$(序列中的第 $i$ 个元素对应值为 $i+1$ 的节点)。

    #include<iostream>
    #include<cstdio>
    #include<cstring>
    #include<cmath>
    #include<cstring>
    
    using namespace std;
    
    // From here on every `int` is really `long long` (hence `signed main`);
    // printf calls must therefore use %lld or cast explicitly.
    #define int long long
    // Large sentinel constant. Parenthesized so that -INF or INF/2 bind as intended.
    #define INF (1<<30)
    
    // Read a decimal integer of any integral type T from stdin into x.
    // Non-digit characters before the number are skipped; only a '-' that
    // directly precedes the digits makes the value negative. If input ends
    // before a digit, x becomes 0 (the old char-based loop never saw EOF and hung).
    template<typename T>
    inline void read(T &x)
    {
    	x=0;
    	int f=1;
    	int s=getchar();   // int, not char, so EOF is distinguishable
    	while(s!=EOF&&(s<'0'||s>'9')) {f=(s=='-')?-1:1;s=getchar();}
    	while('0'<=s&&s<='9'){x=x*10+(s-'0');s=getchar();}
    	x*=f;
    }
    // n: sequence length (read in main, also used by print to skip sentinels).
    // m: operation count; NOTE(review): main declares its own local m that shadows this one.
    int n,m;
    
    // Field accessors for node x.
    #define lson(x) t[x].ch[0]
    #define rson(x) t[x].ch[1]
    #define s(x) t[x].size
    #define vl(x) t[x].v
    #define f(x) t[x].ff
    #define tag(x) t[x].mark
    
    // Capacity of the static node pool; index 0 acts as the null node.
    const int p=1e5+5;
    
    // Splay node for the sequence (implicit-key) tree.
    struct Node{
    	int ch[2];   // children, 0 = none
    	int ff;      // parent, 0 = none
    	int v;       // stored value
    	int size;    // number of nodes in this subtree
    	int mark;    // pending "reverse this subtree" flag
    	
    	// Reset as a leaf with value x under parent fa; mark is left untouched.
    	void init(int x,int fa)
    	{
    		ff = fa;
    		v = x;
    		size = 1;
    		ch[1] = 0;
    		ch[0] = 0;
    	}
    }t[p];
    
    int root;   // index of the tree root (0 = empty)
    int M;
    int tot;    // number of nodes allocated so far
    
    inline void pushup(int x)
    {
    	s(x) = s(lson(x))+s(rson(x))+1;	
    }
    
    // If x carries a pending reversal, hand the flag to both children
    // (toggling, so two reversals cancel) and swap x's own children.
    inline void pushdown(int x)
    {
    	if (!t[x].mark) return;
    	t[t[x].ch[0]].mark ^= 1;
    	t[t[x].ch[1]].mark ^= 1;
    	t[x].mark = 0;
    	swap(t[x].ch[0], t[x].ch[1]);
    }
    
    // Lift x above its parent y (x may be either child); refresh sizes of y,
    // then x. No tags are pushed here.
    inline void rotate(int x)
    {
    	const int y = t[x].ff;
    	const int z = t[y].ff;
    	const int side = (t[y].ch[1] == x);      // 1 when x is y's right child
    	t[z].ch[t[z].ch[1] == y] = x;            // x takes over y's slot under z
    	t[x].ff = z;
    	const int inner = t[x].ch[side ^ 1];     // x's subtree on the side facing y
    	t[y].ch[side] = inner;
    	t[inner].ff = y;
    	t[x].ch[side ^ 1] = y;
    	t[y].ff = x;
    	pushup(y);
    	pushup(x);
    }
    
    inline void splay(int x,int goal)
    {
    	while(f(x)!=goal)
    	{
    		int z = f(f(x));
    		int y = f(x);
    		if(z!=goal) (rson(z)==y)^(rson(y) == x)?rotate(x):rotate(y);
    		rotate(x);
    	}
    	if(goal == 0) root = x;
    }
    
    // BST-insert a new node with value x (equal values go left) and splay it
    // to the root.
    inline void insert(int x)
    {
    	int parent = 0;
    	for (int cur = root; cur != 0; cur = t[cur].ch[t[cur].v < x])
    		parent = cur;
    	const int fresh = ++tot;
    	if (parent != 0) t[parent].ch[t[parent].v < x] = fresh;
    	t[fresh].init(x, parent);
    	splay(fresh, 0);
    }
    
    inline int k_th(int x)
    {
    	int u =root;
    	while(2333)
    	{
    		pushdown(u);
    		int y = lson(u);
    		pushup(u);
    		if(x <= s(y)) {u = lson(u);continue;}
    		if(s(y)+1 == x) return u;
    		else  x-=s(y)+1,u = rson(u);
    		
    	}
    }
    
    // In-order dump of the sequence rooted at u, skipping the two sentinels
    // (values 1 and n+2) and mapping stored value i+1 back to element i.
    // Pending reversal tags are pushed down before the children are read.
    void print(int u)
    {
    	pushdown(u);
    	if(lson(u)) print(lson(u));
    	// The value is long long under `#define int long long`; %d would be UB,
    	// so cast explicitly and print with %lld.
    	if(vl(u)>1 && vl(u)<n+2) printf("%lld ",(long long)(vl(u)-1));
    	if(rson(u)) print(rson(u));
    }
    
    // Reverse elements l..r. Because of the leading sentinel, element i sits at
    // in-order rank i+1, so rank l is just left of the range and rank r+2 just
    // right of it. After the two splays the range is exactly the left subtree
    // of the root's right child, which gets its reversal flag toggled.
    inline void solve(int l,int r)
    {
    	const int before = k_th(l);
    	const int after = k_th(r + 2);
    	splay(before, 0);
    	splay(after, before);
    	const int rangeRoot = t[t[root].ch[1]].ch[0];
    	t[rangeRoot].mark ^= 1;
    }
    
    signed  main()
    {
    	int m;
    	read(n);read(m);
    	for(int i=1;i<=n+2;i++) insert(i);
    	for(int i=1,a,b;i<=m;i++)
    	{
    		read(a);
    		read(b);
    		solve(a,b);
    	}
    	print(root);	
    	printf("
    ");	
    	
    
     }
    

    End & 后记

    (WC)我听了个寂寞,……
    真就划呗……

  • 相关阅读:
    磁盘缓存
    算法与追MM(转)
    人人都能上清华(转)
    软件加密技术和注册机制原理攻略(转)
    计算二重定积分
    C++运算符重载
    STL中list的用法
    累了??放松一下,看几张关于程序员的几张搞笑图片
    解决来QQ消息后歌曲音量降低问题
    搞ACM的你伤不起(转)
  • 原文地址:https://www.cnblogs.com/-Iris-/p/14359248.html
Copyright © 2011-2022 走看看