zoukankan      html  css  js  c++  java
  • 神奇的splay树

    神奇的splay树

    总结

    1. splay树是一种BST,其通过不断的splay操作维持树的平衡;其基本思想是将频率高的点(实际是每次查找的点)通过splay操作旋转到树根
    2. 核心操作:
    • update(x): 维护信息,类似线段树中的push_up

    • rotate(x): 单旋,即将x旋转到其父节点y的位置,需要注意顺序(替换y,x的子树加入y, y最为x的子树)

    • splay(int x,int s): 将x节点旋转到s下方。情况1:x,y,z共线,先rotate(y),再rotate(x); 情况2:不共线,rotate(x) 两次

    • find(int x): 找到后需旋转到根

    • insert(int x): 找到计数++,否则产生新节点并初始化

    • Next(int x,int f): 寻找前驱和后继

    • Delete(int x):先找前驱和后继,然后将前驱旋转到根节点root,后继旋转到root下面,然后删掉x(此时在t[nxt].ch[0])

    • kth(int x): 比较size即可

    • tarjan搞的算法都很难写,而且很容易写错呀,还不好调试吧

    • reference: blog1 blog2 blog3

    模板题luogu3369

    代码

    #include <bits/stdc++.h>
    using namespace std;
    const int N=201000;
    struct splay_tree
    {
       int ff,cnt,ch[2],val,size;
    } t[N];
    int root,tot;//root==0 表示是空树 根节点的ff为0
    void update(int x)//更新节点x
    {
       t[x].size = t[t[x].ch[0]].size+t[t[x].ch[1]].size+t[x].cnt;
       
    }
    void rotate(int x)//对x进行单旋
    {
       int y = t[x].ff; int z =t[y].ff;
       int k = (t[y].ch[1]==x);
       t[z].ch[t[z].ch[1]==y] = x;// 用x替换z节点的儿子节点y
       t[x].ff = z;
       t[y].ch[k] = t[x].ch[1^k];  //先把 x的子树移到y
       t[t[x].ch[1^k]].ff = y;
       t[x].ch[1^k] = y; //y-x 与 x-y关系相反,构建x-y关系
       t[y].ff = x;
       update(y);update(x);// 先更新下面的层
       
    }
    void splay(int x,int s)//将x 旋转到 s下方, s==0 则是旋转到根
    {
       while(t[x].ff!=s)
       {
           int y=t[x].ff,z=t[y].ff;
           if (z!=s)//z==s 意味只需旋转一下x即可
               (t[z].ch[0]==y)^(t[y].ch[0]==x)?rotate(x):rotate(y);//如果 x,z,y同线,先旋转y,再旋转x;否则旋转两次x
           rotate(x);
       }
       if (s==0) //s==0 x旋转到根,更新root
           root=x;
    }
    void find(int x)
    {
       int u=root;
       if (!u)
           return ;//空树
       while(t[u].ch[x>t[u].val] && x!=t[u].val)//x>t[u].val 向右找, x< t[u].val 向左找
           u=t[u].ch[x>t[u].val];
       //也有可能找不到
       splay(u,0);//找到x,将其splay到根
    }
    void insert(int x) //插入操作
    {
       int u=root,ff=0;
       while(u && t[u].val!=x)
       {
           ff=u;
           u=t[u].ch[x>t[u].val];
       }
       if (u)//找到元素x的节点,计数++
           t[u].cnt++;
       else//没有找到则产生新节点
       {
           u=++tot;
           if (ff)//ff!=0 u不是根节点
               t[ff].ch[x>t[ff].val]=u;
           t[u].ch[0]=t[u].ch[1]=0;//初始化t[u]
           t[tot].ff=ff;
           t[tot].val=x;
           t[tot].cnt=1;
           t[tot].size=1;
       }
       splay(u,0);//u splay到根节点
    }
    int Next(int x,int f)//f=0 表示前驱 f=1表示后继
    {
       find(x);// 如果找到x所在节点 会被splay到根节点
       int u=root;
       if (t[u].val>x && f) //find没有找到x
           return u;
       if (t[u].val<x && !f)
           return u;
       //find 找打了x,且此时再根节点上
       u=t[u].ch[f];
       while(t[u].ch[f^1])//左子树的最右边节点/右子树的最左边节点
           u=t[u].ch[f^1];
       return u;
    }
    void Delete(int x)
    {
       int last=Next(x,0);
       int Net=Next(x,1);
       splay(last,0);
       splay(Net,last);  //找到前驱和后继并将前驱splay到根节点,后继splay到根节点下面; 则x代表的节点在是根节点的左儿子
       int del=t[Net].ch[0];
       if (t[del].cnt>1)//计数--
       {
           t[del].cnt--;
           splay(del,0);
       }
       else
           t[Net].ch[0]=0;//彻底删掉
    }
    int kth(int x)
    {
       int u=root;
       while(t[u].size<x)//不存在排名为x
           return 0;
       while(1)
       {
           int y=t[u].ch[0];
           if (x>t[y].size+t[u].cnt) //在右子树
           {
               x-=t[y].size+t[u].cnt;
               u=t[u].ch[1];
           }
           else if (t[y].size>=x)//在左子树
               u=y;
           else //就在u
               return t[u].val;
       }
    }
    
    /*
    插入数值x。
    删除数值x(若有多个相同的数,应只删除一个)。
    查询数值x的排名(若有多个相同的数,应输出最小的排名)。
    查询排名为x的数值。
    求数值x的前驱(前驱定义为小于x的最大的数)。
    求数值x的后继(后继定义为大于x的最小的数)。
    */
    int main()
    {
       int n;
       scanf("%d",&n);
       insert(1e9); //始终保持能够找到前驱和后继
       insert(-1e9);
       while(n--)
       {
           int opt,x;
           scanf("%d%d",&opt,&x);
           if (opt==1)
               insert(x);
           if (opt==2)
               Delete(x);
           if (opt==3)
           {
               find(x);
               printf("%d
    ",t[t[root].ch[0]].size);
           }
           if (opt==4)
               printf("%d
    ",kth(x+1));
           if (opt==5)
               printf("%d
    ",t[Next(x,0)].val);
           if (opt==6)
               printf("%d
    ",t[Next(x,1)].val);
       }
       return 0;
    }
    
  • 相关阅读:
    3.14周末作业
    3.13作业
    文件处理
    字符编码
    基本数据类型总结
    基本数据类型--------------------集合set()
    python入门009
    作业009
    python入门008
    作业008
  • 原文地址:https://www.cnblogs.com/fridayfang/p/11086692.html
Copyright © 2011-2022 走看看