zoukankan      html  css  js  c++  java
  • [学习笔记] 平衡树-Splay

    简介

    Splay是一种平衡二叉树。它通过不断地将某个节点旋转到根节点,使整棵树仍然满足二叉查找树的性质,并且保持平衡而不至于退化成链。

    Splay的时间复杂度是按总复杂度来算的,具体来说,即是:
    从空树开始,做插入、删除、访问操作共M次,树中最多同时存在N个点,
    则总时间复杂度不超过(O(MlogN))

    通常取平均值,表示为单次均摊(O(logN))

    复杂度采用了《算法导论》中的摊还分析,可以看这篇博客的证明:https://blog.csdn.net/qq_31640513/article/details/76944892

    属性

    • 二叉查找树的性质

      能够在这棵树上查找某个值的性质:左儿子的值(<)根节点的值(<)右儿子的值

    • 节点维护信息

      [egin{array}{llllll} hline r t & ext {tot} & f a[i] & operatorname{ch}[i][0 / 1] & v a l[i] & c n t[i] & s z[i] \ hline end{array} ]

      int rt;//根节点
      int tot;//节点个数
      struct node {
          int fa;//父亲节点
          int ch[2];//子节点
          int val;//权值
          int cnt;//权值出现次数
          int sz;//子树大小
      };
    

    这里为了表示每个节点的属性,采用了结构体的形式

    方法

    基本方法

    • maintain(x):在改变节点位置后,将节点(x)(size)更新
    • get(x):判断该节点是左儿子还是右儿子
    • Clear(x):销毁节点(x)
        //在改变节点位置后,将节点x的size更新
        inline void maintain(int x) {
            s[x].sz = s[s[x].ch[0]].sz+s[s[x].ch[1]].sz+s[x].cnt;
        }
    
        //判断该节点是左儿子还是右儿子
        inline bool get(int x) {return x == s[s[x].fa].ch[1];}
    
        //销毁节点x
        inline void Clear(int x) {
            s[x].ch[0] = s[x].ch[1] = s[x].fa = s[x].val = s[x].sz = s[x].cnt = 0;
        }
    

    旋转方法

    必须保证

    • 整棵树的中序遍历不变(不能破坏二叉查找树的性质)
    • 受影响的节点维护的信息依然正确有效。
    • root必须指向旋转后的根节点。

    旋转分为两种:左旋右旋

    具体步骤分析:

    设要旋转的点是(x)(x)的父亲是(y)(y)的父亲是(z)

    分三步:

    1. (y)(x)的子节点相连:如果(x)(y)的左儿子,那么(x)的右儿子与(y)相连
    2. (x)(y)父子相连
    3. (x)(y)的原来的父亲 (z)相连:如果(y)(z)的左儿子,那么(z)的左儿子与(x)相连

    Rotate(x)

    	inline void Rotate(int x) {
            int y = s[x].fa, z = s[y].fa, chk = get(x);
    
            //y与x的子节点相连
            s[y].ch[chk] = s[x].ch[chk ^ 1];
            s[s[x].ch[chk ^ 1]].fa = y;
    
            //x与y父子相连
            s[x].ch[chk ^ 1] = y;
            s[y].fa = x;
    
            // x与y的原来的父亲z相连
            s[x].fa = z;
            if(z) s[z].ch[y == s[z].ch[1]] = x;
    
            //只有x和y的sz变化了
            maintain(y);
            maintain(x);
        }
    

    splay方法

    每访问一个节点后都要强制将其旋转到根节点

    分六种情况:

    img

    • 如果(x)的父亲是根节点,直接将(x)左旋或右旋(图(1,2)
    • 如果(x)的父亲不是根节点,且(x)和它父亲的儿子类型(get(x)==get(f))相同,首先将其父亲左旋或右旋,然后将(x)右旋或左旋(图 (3,4)
    • 如果(x)的父亲不是根节点,且(x)和父亲的儿子类型不同,将(x)左旋再右旋、或者右旋再左旋(图 (5,6)

    splay(x):复杂的过程可以转为下面简单的代码

         //将当前节点转移到根节点
        inline void splay(int x) {
            for(int f = s[x].fa; f; Rotate(x),f = s[x].fa){
                if(s[f].fa) Rotate(get(x) == get(f) ? f : x);
            }
            rt=x;
        }
    

    因为对于当前(x)Rotate(x)旋转方式只有一种,如果是右儿子就左旋,左儿子就右旋,减少了很多思维上的麻烦,也就不用纠结该左还是右了。

    插入方法

    插入方法分三种情况

    • 如果树空了则直接插入根并退出
    • 如果原来权值存在,权值个数加一
    • 树中没有这个值,就新建节点

    一定要按照二叉查找树的性质遍历树

    找到了return前要进行splay()操作,来保证树的平衡

    ins(k)

    	//插入操作
        inline void ins(int k) {
            //如果树空了则直接插入根并退出
            if(!rt) {
                s[++tot].val = k;
                s[tot].cnt++;
                rt = tot;
                maintain(rt);
                return ;
            }
            int now = rt,f = 0;
            while(true) {
                //如果原来权值存在,权值个数加一
                if(s[now].val == k) {
                    s[now].cnt++;
                    maintain(now);
                    maintain(f);
                    splay(now);
                    break;
                }
                //按照二叉查找树的性质遍历树
                f = now;
                now = s[now].ch[s[now].val < k];
                //树中没有这个值,就新建节点
                if(!now) {
                    s[++tot].val = k;
                    s[tot].cnt++;
                    s[tot].fa = f;
                    s[f].ch[s[f].val < k] = tot;
                    maintain(tot);
                    maintain(f);
                    splay(tot);
                    break;
                }
            }
        }
    

    查询x的排名

    还是按照查找二叉树的性质进行查找

    • 如果(x)比当前节点的权值小,向其左子树查找。
    • 如果(x)比当前节点的权值大,将答案加上左子树$size (和当前节点)cnt$的大小,向其右子树查找。
    • 如果 与当前节点的权值相同,将答案加(1)并返回。

    Find(k)

        //查找某个数 返回这个数是第几个
        inline int Find(int k) {
            int res = 0,now = rt;
            while(true) {
                //如果这个数比当前节点小,搜索左子树
                if(k<s[now].val) {
                    now = s[now].ch[0];
                }else {
                    //否则加上右子树的个数
                    res += s[s[now].ch[0]].sz;
                    //中序遍历,如果找到这个节点返回res+1
                    if(k == s[now].val) {
                        splay(now);
                        return res + 1;
                    }
                    res += s[now].cnt;
                    now = s[now].ch[1];
                }
            }
        }
    

    查询排名x的数

    • 如果左子树非空且剩余排名(k)不大于左子树的大小 $ size $,那么向左子树查找。
    • 否则将(k)减去左子树的和根的大小。如果此时(k)的值小于等于(size),则返回根节点的权值,否则继续向右子树查找。

    getKth(k)

        //查询第k个数
        inline int getKth(int k) {
            int now = rt;
            while(true){
                if(s[now].ch[0] && k <= s[s[now].ch[0]].sz){
                    now = s[now].ch[0];
                }else{
                    k -= s[now].cnt + s[s[now].ch[0]].sz;
                    if(k <= 0){
                        splay(now);
                        return s[now].val;
                    }
                    now=s[now].ch[1];
                }
            }
        }
    

    查询前驱和后继

    getPre():查询小于x的最大的数的节点,就是找左儿子的右链

    getNxt():查询大于x的最小的数的节点,就是找右儿子的左链

        //查询小于x的最大的数的节点,就是找左儿子的右链
        inline int getPre() {
            int now = s[rt].ch[0];
            while (s[now].ch[1]) now = s[now].ch[1];
            return now;
        }
    
        //查询大于x的最小的数的节点,同理
        inline int getNxt() {
            int now = s[rt].ch[1];
            while (s[now].ch[0]) now = s[now].ch[0];
            return now;
        }
    

    删除方法

    删除方法具体步骤:

    1. 首先将(x)旋转到根的位置,要用到Find(x)先找到(x)
    2. 如果大于(1),则不需要删除节点,只需要将(cnt-1)
    3. 如果只有一个点,删除这个点之后,将rt变为(0)
    4. 如果左右一个儿子,就将该点删除,并让那一个儿子成为根节点
    5. 否则就将(x)的前驱旋转到根节点,并将(x)的右儿子与根节点相连,将(x)删除。

    del(x)

        inline void del(int k){
            Find(k);//先让该点成为根节点
            if(s[rt].cnt > 1) {//如果大于1,不需要删除节点
                s[rt].cnt--;
                maintain(rt);
                return;
            }
            //如果只有一个点
            if(!s[rt].ch[0] && !s[rt].ch[1]){
                Clear(rt);
                rt = 0;
                return;
            }
            //没有左儿子,让右儿子成为根节点
            if(!s[rt].ch[0]){
                int tmp = rt;
                rt = s[rt].ch[1];
                s[rt].fa=0;
                Clear(tmp);
                return;
            }
            //没有右儿子,让左儿子成为根节点
            if(!s[rt].ch[1]){
                int tmp = rt;
                rt = s[rt].ch[0];
                s[rt].fa = 0;
                Clear(tmp);
                return;
            }
            //有左右儿子,让前驱成为根节点
            int x = getPre() , now = rt;
            splay(x);
            s[s[now].ch[1]].fa = x;
            s[x].ch[1] = s[now].ch[1];
            Clear(now);
            maintain(rt);
        }
    

    模板题

    https://www.luogu.com.cn/problem/P3369

    完整代码:

    #include<bits/stdc++.h>
    using namespace std;
    const int N = 1e5+7;
    
    int rt;//根节点
    int tot;//节点个数
    struct node {
        int fa;//父亲节点
        int ch[2];//子节点
        int val;//权值
        int cnt;//权值出现次数
        int sz;//子树大小
    }s[N];
    
    struct Splay {
    
        //在改变节点位置后,将节点x的size更新
        inline void maintain(int x) {
            s[x].sz = s[s[x].ch[0]].sz+s[s[x].ch[1]].sz+s[x].cnt;
        }
    
        //判断该节点是左儿子还是右儿子
        inline bool get(int x) {return x == s[s[x].fa].ch[1];}
    
        //销毁节点x
        inline void Clear(int x) {
            s[x].ch[0] = s[x].ch[1] = s[x].fa = s[x].val = s[x].sz = s[x].cnt = 0;
        }
    
        inline void Rotate(int x) {
            int y = s[x].fa, z = s[y].fa, chk = get(x);
    
            //y与x的子节点相连
            s[y].ch[chk] = s[x].ch[chk ^ 1];
            s[s[x].ch[chk ^ 1]].fa = y;
    
            //x与y父子相连
            s[x].ch[chk ^ 1] = y;
            s[y].fa = x;
    
            // x与y的原来的父亲z相连
            s[x].fa = z;
            if(z) s[z].ch[y == s[z].ch[1]] = x;
    
            //只有x和y的sz变化了
            maintain(y);
            maintain(x);
        }
        //将当前节点转移到根节点
        inline void splay(int x) {
    
            for(int f = s[x].fa; f; Rotate(x),f = s[x].fa){
                if(s[f].fa) Rotate(get(x) == get(f) ? f : x);
            }
            rt=x;
        }
        //插入操作
        inline void ins(int k) {
            //如果树空了则直接插入根并退出
            if(!rt) {
                s[++tot].val = k;
                s[tot].cnt++;
                rt = tot;
                maintain(rt);
                return ;
            }
            int now = rt,f = 0;
            while(true) {
                //如果原来权值存在,权值个数加一
                if(s[now].val == k) {
                    s[now].cnt++;
                    maintain(now);
                    maintain(f);
                    splay(now);
                    break;
                }
                //按照二叉查找树的性质遍历树
                f = now;
                now = s[now].ch[s[now].val < k];
                //树中没有这个值,就新建节点
                if(!now) {
                    s[++tot].val = k;
                    s[tot].cnt++;
                    s[tot].fa = f;
                    s[f].ch[s[f].val < k] = tot;
                    maintain(tot);
                    maintain(f);
                    splay(tot);
                    break;
                }
            }
        }
        //查找某个数 返回这个数是第几个
        inline int Find(int k) {
            int res = 0,now = rt;
            while(true) {
                //如果这个数比当前节点小,搜索左子树
                if(k<s[now].val) {
                    now = s[now].ch[0];
                }else {
                    //否则加上右子树的个数
                    res += s[s[now].ch[0]].sz;
                    //中序遍历,如果找到这个节点返回res+1
                    if(k == s[now].val) {
                        splay(now);
                        return res + 1;
                    }
                    res += s[now].cnt;
                    now = s[now].ch[1];
                }
            }
        }
    
        //查询小于x的最大的数的节点,就是找左儿子的右链
        inline int getPre() {
            int now = s[rt].ch[0];
            while (s[now].ch[1]) now = s[now].ch[1];
            return now;
        }
    
        //查询大于x的最小的数的节点,同理
        inline int getNxt() {
            int now = s[rt].ch[1];
            while (s[now].ch[0]) now = s[now].ch[0];
            return now;
        }
    
        //查询第k个数
        inline int getKth(int k) {
            int now = rt;
            while(true){
                if(s[now].ch[0] && k <= s[s[now].ch[0]].sz){
                    now = s[now].ch[0];
                }else{
                    k -= s[now].cnt + s[s[now].ch[0]].sz;
                    if(k <= 0){
                        splay(now);
                        return s[now].val;
                    }
                    now=s[now].ch[1];
                }
            }
        }
    
        //删除结点
        inline void del(int k){
            Find(k);//先让该点成为根节点
            if(s[rt].cnt > 1) {//如果大于1,不需要删除节点
                s[rt].cnt--;
                maintain(rt);
                return;
            }
            //如果只有一个点
            if(!s[rt].ch[0] && !s[rt].ch[1]){
                Clear(rt);
                rt = 0;
                return;
            }
            //没有左儿子,让右儿子成为根节点
            if(!s[rt].ch[0]){
                int tmp = rt;
                rt = s[rt].ch[1];
                s[rt].fa=0;
                Clear(tmp);
                return;
            }
            //没有右儿子,让左儿子成为根节点
            if(!s[rt].ch[1]){
                int tmp = rt;
                rt = s[rt].ch[0];
                s[rt].fa = 0;
                Clear(tmp);
                return;
            }
            //有左右儿子,让前驱成为根节点
            int x = getPre() , now = rt;
            splay(x);
            s[s[now].ch[1]].fa = x;
            s[x].ch[1] = s[now].ch[1];
            Clear(now);
            maintain(rt);
        }
    }st;
    int main(){
        int n,opt,x;
        scanf("%d",&n);
        while(n--){
            scanf("%d%d",&opt,&x);
            if(opt == 1) st.ins(x);
            else if(opt == 2) st.del(x);
            else if(opt == 3) printf("%d
    ",st.Find(x));
            else if(opt == 4) printf("%d
    ",st.getKth(x));
            else if(opt == 5) {
                st.ins(x);
                printf("%d
    ",s[st.getPre()].val);
                st.del(x);
            }
            else {
                st.ins(x);
                printf("%d
    ",s[st.getNxt()].val);
                st.del(x);
            }
        }
        return 0;
    }
    
    

    oiwiki上的代码没有用结构体写起来比较快:

    #include <cstdio>
    const int N = 100005;
    int rt, tot, fa[N], ch[N][2], val[N], cnt[N], sz[N];
    struct Splay {
      void maintain(int x) { sz[x] = sz[ch[x][0]] + sz[ch[x][1]] + cnt[x]; }
      bool get(int x) { return x == ch[fa[x]][1]; }
      void clear(int x) {
        ch[x][0] = ch[x][1] = fa[x] = val[x] = sz[x] = cnt[x] = 0;
      }
      void rotate(int x) {
        int y = fa[x], z = fa[y], chk = get(x);
        ch[y][chk] = ch[x][chk ^ 1];
        fa[ch[x][chk ^ 1]] = y;
        ch[x][chk ^ 1] = y;
        fa[y] = x;
        fa[x] = z;
        if (z) ch[z][y == ch[z][1]] = x;
        maintain(x);
        maintain(y);
      }
      void splay(int x) {
        for (int f = fa[x]; f = fa[x], f; rotate(x))
          if (fa[f]) rotate(get(x) == get(f) ? f : x);
        rt = x;
      }
      void ins(int k) {
        if (!rt) {
          val[++tot] = k;
          cnt[tot]++;
          rt = tot;
          maintain(rt);
          return;
        }
        int cnr = rt, f = 0;
        while (1) {
          if (val[cnr] == k) {
            cnt[cnr]++;
            maintain(cnr);
            maintain(f);
            splay(cnr);
            break;
          }
          f = cnr;
          cnr = ch[cnr][val[cnr] < k];
          if (!cnr) {
            val[++tot] = k;
            cnt[tot]++;
            fa[tot] = f;
            ch[f][val[f] < k] = tot;
            maintain(tot);
            maintain(f);
            splay(tot);
            break;
          }
        }
      }
      int rk(int k) {
        int res = 0, cnr = rt;
        while (1) {
          if (k < val[cnr]) {
            cnr = ch[cnr][0];
          } else {
            res += sz[ch[cnr][0]];
            if (k == val[cnr]) {
              splay(cnr);
              return res + 1;
            }
            res += cnt[cnr];
            cnr = ch[cnr][1];
          }
        }
      }
      int kth(int k) {
        int cnr = rt;
        while (1) {
          if (ch[cnr][0] && k <= sz[ch[cnr][0]]) {
            cnr = ch[cnr][0];
          } else {
            k -= cnt[cnr] + sz[ch[cnr][0]];
            if (k <= 0) return val[cnr];
            cnr = ch[cnr][1];
          }
        }
      }
      int pre() {
        int cnr = ch[rt][0];
        while (ch[cnr][1]) cnr = ch[cnr][1];
        return cnr;
      }
      int nxt() {
        int cnr = ch[rt][1];
        while (ch[cnr][0]) cnr = ch[cnr][0];
        return cnr;
      }
      void del(int k) {
        rk(k);
        if (cnt[rt] > 1) {
          cnt[rt]--;
          maintain(rt);
          return;
        }
        if (!ch[rt][0] && !ch[rt][1]) {
          clear(rt);
          rt = 0;
          return;
        }
        if (!ch[rt][0]) {
          int cnr = rt;
          rt = ch[rt][1];
          fa[rt] = 0;
          clear(cnr);
          return;
        }
        if (!ch[rt][1]) {
          int cnr = rt;
          rt = ch[rt][0];
          fa[rt] = 0;
          clear(cnr);
          return;
        }
        int x = pre(), cnr = rt;
        splay(x);
        fa[ch[cnr][1]] = x;
        ch[x][1] = ch[cnr][1];
        clear(cnr);
        maintain(rt);
      }
    } tree;
    
    int main() {
      int n, opt, x;
      for (scanf("%d", &n); n; --n) {
        scanf("%d%d", &opt, &x);
        if (opt == 1)
          tree.ins(x);
        else if (opt == 2)
          tree.del(x);
        else if (opt == 3)
          printf("%d
    ", tree.rk(x));
        else if (opt == 4)
          printf("%d
    ", tree.kth(x));
        else if (opt == 5)
          tree.ins(x), printf("%d
    ", val[tree.pre()]), tree.del(x);
        else
          tree.ins(x), printf("%d
    ", val[tree.nxt()]), tree.del(x);
      }
      return 0;
    }
    

    来源

    [1] https://oi-wiki.org/ds/splay/

  • 相关阅读:
    Connected Graph
    Gerald and Giant Chess
    [NOI2009]诗人小G
    四边形不等式小结
    [NOI2007]货币兑换
    Cats Transport
    Cut the Sequence
    Fence
    The Battle of Chibi
    [Usaco2005 Dec]Cleaning Shifts
  • 原文地址:https://www.cnblogs.com/smallocean/p/12410955.html
Copyright © 2011-2022 走看看