zoukankan      html  css  js  c++  java
  • 平衡树之splay总结

    前置芝士:

    平衡树:可以自平衡的二叉排序树,任然具有 左儿子<父亲<右儿子 的特点,且可保证不会退化成链,保证时间复杂度为(nlogn)

    旋转:我的splay中只存在上旋(即将某个节点向上旋转),不区分左旋和右旋

    前驱:比某个数小的最大数

    后驱:比某个数大的最小数

    平衡树的定义:

    ll root=0,decnt=0;//root表示splay的根节点 decnt代表新建节点编号 
    ll ch[maxn][2],size[maxn],cnt[maxn],val[maxn],prt[maxn],rev[maxn];
    //ch[v][0]表示v的左儿子 ch[v][1]表示v的右儿子 prt[v]表示v的父亲 
    //val[v]表示v的值 size[v]表示以v为根节点的子树的节点总数 cnt[v]表示值为val[v]的点的个数
    //rev[v]==1时代表要区间翻转 rev[v]==0时表示不需要区间翻转 
    View Code

    更新:

    void pushup(ll v){size[v]=size[ch[v][0]]+size[ch[v][1]]+cnt[v];}
    View Code

    然后是所有平衡树都会用到的旋转操作:

    void rotate(ll v){
        ll y=prt[v],z=prt[y],d=chk(v),k=ch[v][d^1];
        ch[y][d]=k;prt[k]=y;
        ch[z][chk(y)]=v;prt[v]=z;
        ch[v][d^1]=y;prt[y]=v;
        pushup(y),pushup(v);
    }
    View Code

    接下来就是splay的核心操作------splay操作  本质就是把一个节点旋到某个节点的儿子处(默认为0的儿子,即旋到根节点):

    void splay(ll cur,ll v=0){
        while(prt[cur]!=v){
            ll pr=prt[cur];
            if(prt[pr]!=v){
                if(chk(cur)==chk(pr))rotate(pr);
                else rotate(cur);
            }
            rotate(cur);
        }
        if(!v)root=cur;
    }
    View Code

    插入操作

    void insert(ll x){
        ll cur=root,p=0;
        while(cur&&x!=val[cur])p=cur,cur=ch[cur][x>val[cur]];
        if(cur)cnt[cur]++;
        else{
            cur=++decnt;
            if(p)ch[p][x>val[p]]=cur;
            ch[cur][0]=ch[cur][1]=0;
            val[cur]=x;prt[cur]=p;
            size[cur]=cnt[cur]=1;
        }
        splay(cur);
    }
    View Code

    查找操作,即找到某个节点并把他旋转到根节点

    void find(ll x){
        ll cur=root;
        while(ch[cur][x>val[cur]]&&x!=val[cur])cur=ch[cur][x>val[cur]];
        splay(cur);
    }
    View Code

    求某个树的排名,只需要把他旋转到根节点,排名就是他的左子树的节点数+1

    ll rank(ll x){
        find(x);return size[ch[root][0]];
        //本来应该是返回 size[ch[root][0]]+1但为了避免溢出,我先insert了inf和-inf,所以排名就应该-1 
    }
    View Code

    求第k大(调用时应该是kth(k+1),原因同上)

    ll kth(ll k){
        ll cur=root;
        while(true){
            pushdown(cur);
            if(ch[cur][0]&&size[ch[cur][0]]>=k)cur=ch[cur][0];
            else if(ch[cur][1]&&size[ch[cur][0]]+cnt[cur]<k)k-=size[ch[cur][0]]+cnt[cur],cur=ch[cur][1];
            else return cur;
        }
        return cur;
    }
    View Code

    求前驱,把这个数旋到根,并在左子树中找最大值

    ll pre(ll x){
        find(x);
        if(val[root]<x)return root;//特判一下,防止出现查找不存在的数的情况 
        ll cur=ch[root][0];
        while(ch[cur][1])cur=ch[cur][1];
        return cur;
    }
    View Code

    求后驱,把这个数旋到根,并在右子树中找最小值

    ll succ(ll x){
        find(x);
        if(val[root]>x)return root;//特判一下,防止出现查找不存在的数的情况
        ll cur=ch[root][1];
        while(ch[cur][0])cur=ch[cur][0];
        return cur;
    }
    View Code

    删除某个数,只需要把他的前驱旋转到根,把他的后驱旋转到根的左儿子,因为大于他的前驱,所以他在根的右子树,又因为他小于后驱且除他之外没有小于后驱而大于前驱的数,所以他的后驱的左子树只有他一个节点

    void remove(ll x){
        ll lst=pre(x),nxt=succ(x);
        splay(lst),splay(nxt,lst);
        ll del=ch[nxt][0];
        if(cnt[del]>1){
            cnt[del]--;splay(del);
        }else ch[nxt][0]=0;
    }
    View Code

    区间翻转,打标记就好了

    void reverse(ll l,ll r){   //这个只在所有节点编号为1~n的时候能用
        ll x=kth(l),y=kth(r+2);
        splay(x),splay(y,x);
        rev[ch[y][0]]^=1;
    }
    View Code

    输出序列,就中序遍历一遍就好了

    void print(ll v){
        if(!v)return;
        pushdown(v);
        print(ch[v][0]);
        if(val[v]!=inf&&val[v]!=inf)for(ll i=1;i<=cnt[v];i++)printf("%lld ",val[v]);
        print(ch[v][1]);
    }
    View Code

    综上,splay的代码如下

    namespace splay{
        const ll inf=1ll<<30;
        const ll maxn=200010;
        ll root=0,decnt=0;//root表示splay的根节点 decnt代表新建节点编号 
        ll ch[maxn][2],size[maxn],cnt[maxn],val[maxn],prt[maxn],rev[maxn];
        //ch[v][0]表示v的左儿子 ch[v][1]表示v的右儿子 prt[v]表示v的父亲 
        //val[v]表示v的值 size[v]表示以v为根节点的子树的节点总数 cnt[v]表示值为val[v]的点的个数
        //rev[v]==1时代表要区间翻转 rev[v]==0时表示不需要区间翻转 
        ll chk(ll v){return ch[prt[v]][1]==v;}
        void swap(ll &a,ll &b){a^=b^=a^=b;}
        void pushup(ll v){size[v]=size[ch[v][0]]+size[ch[v][1]]+cnt[v];}
        void pushdown(ll v){
            if(rev[v]){
                swap(ch[v][0],ch[v][1]);
                rev[ch[v][0]]^=1,rev[ch[v][1]]^=1;
                rev[v]=0;
            }
        }
        void rotate(ll v){
            ll y=prt[v],z=prt[y],d=chk(v),k=ch[v][d^1];
            ch[y][d]=k;prt[k]=y;
            ch[z][chk(y)]=v;prt[v]=z;
            ch[v][d^1]=y;prt[y]=v;
            pushup(y),pushup(v);
        }
        void splay(ll cur,ll v=0){
            while(prt[cur]!=v){
                ll pr=prt[cur];
                if(prt[pr]!=v){
                    if(chk(cur)==chk(pr))rotate(pr);
                    else rotate(cur);
                }
                rotate(cur);
            }
            if(!v)root=cur;
        }
        void insert(ll x){
            ll cur=root,p=0;
            while(cur&&x!=val[cur])p=cur,cur=ch[cur][x>val[cur]];
            if(cur)cnt[cur]++;
            else{
                cur=++decnt;
                if(p)ch[p][x>val[p]]=cur;
                ch[cur][0]=ch[cur][1]=0;
                val[cur]=x;prt[cur]=p;
                size[cur]=cnt[cur]=1;
            }
            splay(cur);
        }
        void find(ll x){
            ll cur=root;
            while(ch[cur][x>val[cur]]&&x!=val[cur])cur=ch[cur][x>val[cur]];
            splay(cur);
        }
        ll rank(ll x){
            find(x);return size[ch[root][0]];
            //本来应该是返回 size[ch[root][0]]+1但为了避免溢出,我先insert了inf和-inf,所以排名就应该-1 
        }
        ll kth(ll k){
            ll cur=root;
            while(true){
                pushdown(cur);
                if(ch[cur][0]&&size[ch[cur][0]]>=k)cur=ch[cur][0];
                else if(ch[cur][1]&&size[ch[cur][0]]+cnt[cur]<k)k-=size[ch[cur][0]]+cnt[cur],cur=ch[cur][1];
                else return cur;
            }
            return cur;
        }
        ll pre(ll x){
            find(x);
            if(val[root]<x)return root;//特判一下,防止出现查找不存在的数的情况 
            ll cur=ch[root][0];
            while(ch[cur][1])cur=ch[cur][1];
            return cur;
        }
        ll succ(ll x){
            find(x);
            if(val[root]>x)return root;//特判一下,防止出现查找不存在的数的情况
            ll cur=ch[root][1];
            while(ch[cur][0])cur=ch[cur][0];
            return cur;
        }
        void remove(ll x){
            ll lst=pre(x),nxt=succ(x);
            splay(lst),splay(nxt,lst);
            ll del=ch[nxt][0];
            if(cnt[del]>1){
                cnt[del]--;splay(del);
            }else ch[nxt][0]=0;
        }
        void reverse(ll l,ll r){
            ll x=kth(l),y=kth(r+2);
            splay(x),splay(y,x);
            rev[ch[y][0]]^=1;
        }
        void print(ll v){
            if(!v)return;
            pushdown(v);
            print(ch[v][0]);
            if(val[v]!=inf&&val[v]!=inf)for(ll i=1;i<=cnt[v];i++)printf("%lld ",val[v]);
            print(ch[v][1]);
        }
    }
    View Code

    Luogu P3369 【模板】普通平衡树

    #include<cstdio>
    #define ll long long
    namespace splay{
        const ll inf=1ll<<30;
        const ll maxn=200010;
        ll root=0,decnt=0;//root表示splay的根节点 decnt代表新建节点编号 
        ll ch[maxn][2],size[maxn],cnt[maxn],val[maxn],prt[maxn],rev[maxn];
        //ch[v][0]表示v的左儿子 ch[v][1]表示v的右儿子 prt[v]表示v的父亲 
        //val[v]表示v的值 size[v]表示以v为根节点的子树的节点总数 cnt[v]表示值为val[v]的点的个数
        //rev[v]==1时代表要区间翻转 rev[v]==0时表示不需要区间翻转 
        ll chk(ll v){return ch[prt[v]][1]==v;}
        void swap(ll &a,ll &b){a^=b^=a^=b;}
        void pushup(ll v){size[v]=size[ch[v][0]]+size[ch[v][1]]+cnt[v];}
        void pushdown(ll v){
            if(rev[v]){
                swap(ch[v][0],ch[v][1]);
                rev[ch[v][0]]^=1,rev[ch[v][1]]^=1;
                rev[v]=0;
            }
        }
        void rotate(ll v){
            ll y=prt[v],z=prt[y],d=chk(v),k=ch[v][d^1];
            ch[y][d]=k;prt[k]=y;
            ch[z][chk(y)]=v;prt[v]=z;
            ch[v][d^1]=y;prt[y]=v;
            pushup(y),pushup(v);
        }
        void splay(ll cur,ll v=0){
            while(prt[cur]!=v){
                ll pr=prt[cur];
                if(prt[pr]!=v){
                    if(chk(cur)==chk(pr))rotate(pr);
                    else rotate(cur);
                }
                rotate(cur);
            }
            if(!v)root=cur;
        }
        void insert(ll x){
            ll cur=root,p=0;
            while(cur&&x!=val[cur])p=cur,cur=ch[cur][x>val[cur]];
            if(cur)cnt[cur]++;
            else{
                cur=++decnt;
                if(p)ch[p][x>val[p]]=cur;
                ch[cur][0]=ch[cur][1]=0;
                val[cur]=x;prt[cur]=p;
                size[cur]=cnt[cur]=1;
            }
            splay(cur);
        }
        void find(ll x){
            ll cur=root;
            while(ch[cur][x>val[cur]]&&x!=val[cur])cur=ch[cur][x>val[cur]];
            splay(cur);
        }
        ll rank(ll x){
            find(x);return size[ch[root][0]];
            //本来应该是返回 size[ch[root][0]]+1但为了避免溢出,我先insert了inf和-inf,所以排名就应该-1 
        }
        ll kth(ll k){
            ll cur=root;
            while(true){
                pushdown(cur);
                if(ch[cur][0]&&size[ch[cur][0]]>=k)cur=ch[cur][0];
                else if(ch[cur][1]&&size[ch[cur][0]]+cnt[cur]<k)k-=size[ch[cur][0]]+cnt[cur],cur=ch[cur][1];
                else return cur;
            }
            return cur;
        }
        ll pre(ll x){
            find(x);
            if(val[root]<x)return root;//特判一下,防止出现查找不存在的数的情况 
            ll cur=ch[root][0];
            while(ch[cur][1])cur=ch[cur][1];
            return cur;
        }
        ll succ(ll x){
            find(x);
            if(val[root]>x)return root;//特判一下,防止出现查找不存在的数的情况
            ll cur=ch[root][1];
            while(ch[cur][0])cur=ch[cur][0];
            return cur;
        }
        void remove(ll x){
            ll lst=pre(x),nxt=succ(x);
            splay(lst),splay(nxt,lst);
            ll del=ch[nxt][0];
            if(cnt[del]>1){
                cnt[del]--;splay(del);
            }else ch[nxt][0]=0;
        }
        void reverse(ll l,ll r){
            ll x=kth(l),y=kth(r+2);
            splay(x),splay(y,x);
            rev[ch[y][0]]^=1;
        }
        void print(ll v){
            if(!v)return;
            pushdown(v);
            print(ch[v][0]);
            if(val[v]!=inf&&val[v]!=-inf)for(ll i=1;i<=cnt[v];i++)printf("%lld ",val[v]);
            print(ch[v][1]);
        }
    }
    using namespace splay;
    ll n;
    int main(){
        scanf("%lld",&n);
        insert(inf);
        insert(-inf);
        while(n--){
            ll opt,x;
            scanf("%lld%lld",&opt,&x);
            switch(opt){
                case 1:{insert(x);break;}
                case 2:{remove(x);break;}
                case 3:{printf("%lld
    ",rank(x));break;}
                case 4:{printf("%lld
    ",val[kth(x+1)]);break;}
                case 5:{printf("%lld
    ",val[pre(x)]);break;}
                case 6:{printf("%lld
    ",val[succ(x)]);break;}
            }
        }
        return 0;
    }
    View Code

    Luogu P3391 【模板】文艺平衡树(Splay)

    #include<cstdio>
    #define ll long long
    ll n,m;
    namespace splay{
        const ll inf=1ll<<30;
        const ll maxn=200010;
        ll root=0,decnt=0;//root表示splay的根节点 decnt代表新建节点编号 
        ll ch[maxn][2],size[maxn],cnt[maxn],val[maxn],prt[maxn],rev[maxn];
        //ch[v][0]表示v的左儿子 ch[v][1]表示v的右儿子 prt[v]表示v的父亲 
        //val[v]表示v的值 size[v]表示以v为根节点的子树的节点总数 cnt[v]表示值为val[v]的点的个数
        //rev[v]==1时代表要区间翻转 rev[v]==0时表示不需要区间翻转 
        ll chk(ll v){return ch[prt[v]][1]==v;}
        void swap(ll &a,ll &b){a^=b^=a^=b;}
        void pushup(ll v){size[v]=size[ch[v][0]]+size[ch[v][1]]+cnt[v];}
        void pushdown(ll v){
            if(rev[v]){
                swap(ch[v][0],ch[v][1]);
                rev[ch[v][0]]^=1,rev[ch[v][1]]^=1;
                rev[v]=0;
            }
        }
        void rotate(ll v){
            ll y=prt[v],z=prt[y],d=chk(v),k=ch[v][d^1];
            ch[y][d]=k;prt[k]=y;
            ch[z][chk(y)]=v;prt[v]=z;
            ch[v][d^1]=y;prt[y]=v;
            pushup(y),pushup(v);
        }
        void splay(ll cur,ll v=0){
            while(prt[cur]!=v){
                ll pr=prt[cur];
                if(prt[pr]!=v){
                    if(chk(cur)==chk(pr))rotate(pr);
                    else rotate(cur);
                }
                rotate(cur);
            }
            if(!v)root=cur;
        }
        void insert(ll x){
            ll cur=root,p=0;
            while(cur&&x!=val[cur])p=cur,cur=ch[cur][x>val[cur]];
            if(cur)cnt[cur]++;
            else{
                cur=++decnt;
                if(p)ch[p][x>val[p]]=cur;
                ch[cur][0]=ch[cur][1]=0;
                val[cur]=x;prt[cur]=p;
                size[cur]=cnt[cur]=1;
            }
            splay(cur);
        }
        void find(ll x){
            ll cur=root;
            while(ch[cur][x>val[cur]]&&x!=val[cur])cur=ch[cur][x>val[cur]];
            splay(cur);
        }
        ll rank(ll x){
            find(x);return size[ch[root][0]];
            //本来应该是返回 size[ch[root][0]]+1但为了避免溢出,我先insert了inf和-inf,所以排名就应该-1 
        }
        ll kth(ll k){
            ll cur=root;
            while(true){
                pushdown(cur);
                if(ch[cur][0]&&size[ch[cur][0]]>=k)cur=ch[cur][0];
                else if(ch[cur][1]&&size[ch[cur][0]]+cnt[cur]<k)k-=size[ch[cur][0]]+cnt[cur],cur=ch[cur][1];
                else return cur;
            }
            return cur;
        }
        ll pre(ll x){
            find(x);
            if(val[root]<x)return root;//特判一下,防止出现查找不存在的数的情况 
            ll cur=ch[root][0];
            while(ch[cur][1])cur=ch[cur][1];
            return cur;
        }
        ll succ(ll x){
            find(x);
            if(val[root]>x)return root;//特判一下,防止出现查找不存在的数的情况
            ll cur=ch[root][1];
            while(ch[cur][0])cur=ch[cur][0];
            return cur;
        }
        void remove(ll x){
            ll lst=pre(x),nxt=succ(x);
            splay(lst),splay(nxt,lst);
            ll del=ch[nxt][0];
            if(cnt[del]>1){
                cnt[del]--;splay(del);
            }else ch[nxt][0]=0;
        }
        void reverse(ll l,ll r){
            ll x=kth(l),y=kth(r+2);
            splay(x),splay(y,x);
            rev[ch[y][0]]^=1;
        }
        void print(ll v){
            if(!v)return;
            pushdown(v);
            print(ch[v][0]);
            if(val[v]!=inf&&val[v]!=-inf)for(ll i=1;i<=cnt[v];i++)printf("%lld ",val[v]);
            print(ch[v][1]);
        }
    }
    using namespace splay;
    int main(){
        scanf("%lld%lld",&n,&m);
        insert(inf);
        insert(-inf);
        for(ll i=1;i<=n;i++)insert(i);
        while(m--){
            ll l,r;
            scanf("%lld%lld",&l,&r);
            reverse(l,r);
        }
        print(root);
        return 0;
    }
    View Code
  • 相关阅读:
    算法的学习 — 冒泡排序
    自定义UICollectionLayout布局 —— UIKit之学习UICollectionView记录一《瀑布流》
    HDU 1541 Stars (线段树||树状数组)
    HDU 1617 Phone List (排序||字典树)
    CSU 1312 CX and girls (最短路)
    CSU 1320 Scoop water (卡特兰数)
    POJ 1838 Banana (并查集)
    POJ 1837 Balance (DP)
    POJ 1088 滑雪 (记忆化搜索)
    TYVJ 1261 可达总数 (BFS)
  • 原文地址:https://www.cnblogs.com/Railgunforever/p/10122750.html
Copyright © 2011-2022 走看看