zoukankan      html  css  js  c++  java
  • 浅谈伸展树Splay

    普通平衡树

    Description

    设计一种数据结构,支持插入元素,删除元素,查询值为val的元素的排名,查询排名为rnk的值,查询x的前驱、后驱

    Solution

    Splay的基本操作,熟悉一下Splay,这些操作事实上与Treap也能解决。
    为了实现Splay,我们有如下定义及实现方法。
    1.定义结构体Splay,成员同Treap的定义
    2.定义update函数用于维护节点的sum值
    直接加法运算即可
    3.定义connect函数用于建立父子之间的关系
    直接赋值即可
    4.定义rotate函数用于Splay的旋转
    以左旋为例,假设旋转的节点为x,他的父亲为y,他的右子树为B,他的祖父为z,那么我们令B的父亲为y,y的父亲为x,x的父亲为z
    5.定义splay函数用于实现平衡树的伸展操作
    假设当前节点为x,需要伸展到的节点为to,假定x的父亲、祖父为y,z,那么假设x,y都是其父亲的左/ 右儿子,那么我们旋转y,x,否则我们旋转两次x
    6.定义insert函数用于实现元素的插入
    首先根据BST的性质找到插入的位置,如果这个位置有节点那么cnt++,否则新建节点并赋值
    7.定义find函数用于找到某个值的节点的编号并且将这个节点伸展到树根
    根据BST的性质找到位置并调用splay函数实现
    8.定义calc函数用于计算并返回排名为x的数
    根据BST的性质,假定我们要查询排名为x的元素,那么假设x小于等于左子树的大小那么进入左子树,如果x大于左子树+该节点重复的次数那么 进入右子树,否则返回当前节点
    9.定义query函数用于计算并返回某个元素的前驱后驱的编号
    首先调用find()使得目标节点伸展到树根上,如果目标节点的值恰好符合题意那么直接返回,否则利用BST的性质找到答案
    10.定义del函数用于删除某个元素
    假设删除的元素值为x,那么我们找到他的前后驱,执行两次伸展,直接删除即可完成

    Code

      1 #include <bits/stdc++.h>
      2 using namespace std;
      3 const int INF = 2147483647;
      4 inline int read() {
      5     int ret = 0, op = 1;
      6     char c = getchar();
      7     while (!isdigit(c)) {
      8         if (c == '-') op = -1; 
      9         c = getchar();
     10     }
     11     while (isdigit(c)) {
     12         ret = ret * 10 + c - '0';
     13         c = getchar();
     14     }
     15     return ret * op;
     16 }
     17 struct Splay {
     18     int ch[2];
     19     int cnt, sum;
     20     int val, fa;
     21 } a[100010];
     22 int tot, root;
     23 void update(int now) {
     24     a[now].sum = a[a[now].ch[0]].sum + a[a[now].ch[1]].sum + a[now].cnt;
     25 }
     26 void connect(int x, int fa, int op) {
     27     a[x].fa = fa;
     28     a[fa].ch[op] = x;
     29 }
     30 void rotate(int x) {
     31     int y = a[x].fa;
     32     int z = a[y].fa;
     33     int xson = a[y].ch[1] == x ? 1 : 0;
     34     int yson = a[z].ch[1] == y ? 1 : 0;
     35     int B = a[x].ch[xson ^ 1];
     36     connect(B, y, xson); connect(y, x, xson ^ 1); connect(x, z, yson);
     37     update(y); update(x);
     38 }
     39 void splay(int from, int to) {
     40     while (a[from].fa != to) {
     41         int y = a[from].fa;
     42         int z = a[y].fa;
     43         if (z != to)
     44             (a[y].ch[0] == from) ^ (a[z].ch[0] == y) ? update(from) : update(y);
     45         rotate(from);
     46     }
     47     if (to == 0) root = from; 
     48 }
     49 void insert(int val) {
     50     int now = root, fa = 0;
     51     while (now && a[now].val != val) {
     52         fa = now;
     53         now = a[now].ch[val > a[now].val];
     54     }
     55     if(now) {
     56         a[now].cnt++;
     57     }
     58     else {
     59         a[now = ++tot].val = val;
     60         a[tot].sum = a[tot].cnt = 1;
     61         a[tot].fa = fa;
     62         a[tot].ch[0] = a[tot].ch[1] = 0;
     63         if (fa) a[fa].ch[val > a[fa].val] = tot;
     64     }
     65     splay(now, 0);
     66 }
     67 void find(int x) {
     68     int now = root;
     69     if (now == 0) return ;
     70     while (a[now].val != x && a[now].ch[a[now].val < x]) now = a[now].ch[a[now].val < x];
     71     splay(now, 0);
     72 }
     73 int calc(int x) {
     74     int now = root;
     75     if (a[now].sum < x) return 0;
     76     while (1) {
     77         int y = a[now].ch[0];
     78         if (x > a[y].sum + a[now].cnt) {
     79             x -= a[y].sum + a[now].cnt;
     80             now = a[now].ch[1];
     81         }
     82         else if (x <= a[y].sum) now = y;
     83         else return a[now].val;
     84     } 
     85 }
     86 int query(int x, int op) {
     87     find(x);
     88     int now = root;
     89     if ((op && a[now].val > x) || (a[now].val < x && !op)) return now;
     90     now = a[now].ch[op];
     91     while (a[now].ch[op ^ 1]) now = a[now].ch[op ^ 1];
     92     return now;
     93 }
     94 void del(int x) {
     95     int pre = query(x, 0);
     96     int nxt = query(x, 1);
     97     splay(pre, 0); splay(nxt, pre);
     98     int now = a[nxt].ch[0];
     99     if (a[now].cnt > 1) {
    100         a[now].cnt--;
    101         splay(now, 0);
    102         return ;
    103     }
    104     else a[nxt].ch[0] = 0;
    105 }
    106 int main() {
    107     insert(-INF);
    108     insert(INF);
    109     int m = read();
    110     while (m--) {
    111         int op = read(), x = read();
    112         if (op == 1) {
    113             insert(x);
    114         }
    115         else if (op == 2) {
    116             del(x);
    117         }
    118         else if (op == 3) {
    119             find(x);
    120             printf("%d
    ", a[a[root].ch[0]].sum);
    121         }
    122         else if (op == 4) {
    123             printf("%d
    ", calc(x + 1));
    124         }
    125         else if (op == 5) {
    126             printf("%d
    ", a[query(x, 0)].val);
    127         }
    128         else {
    129             printf("%d
    ", a[query(x, 1)].val);
    130         }
    131     } 
    132     return 0;
    133 }
    AC Code

    文艺平衡树

    Description

     写一种数据结构,维护一个序列,并支持区间翻转

    Solution

    Splay的经典操作:维护区间翻转
    对于区间翻转这种操作,由于原序列不能排序,所以我们不能建立一棵权值树,所以我们按照节点的编号建立一棵平衡树。
    相关函数的定义如下:
    1.定义update函数用于维护节点的sum值
    同上
    2.定义splay函数用于实现平衡树的伸展操作
    同上
    3.定义find函数用于找到某个值的节点的编号
    根据BST的性质找到位置即可
    4.定义rotate函数用于用于Splay的旋转
    同上
    5.定义build函数用于建立平衡树
    确切的讲,我们仿照线段树的建树方式,首先建立当前节点,然后递归建立其左右儿子,然后调用update()维护信息即可
    6.定义reverse函数用于实现区间翻转
    假定我们翻转的区间为[l,r],那么我们调用splay()将l-1伸展到根节点,再调用一次splay()将r+1伸展到根节点的右儿子,这样我们只需要在根节点的右儿子的左儿子打一个标记即可。
    7.定义pushdown函数用于下放标记
    类比线段树,每一次翻转操作我们都会在相应的区间打标记,下放标记时将当前节点的标记清空,同时交换两个儿子,并且更新儿子的标记即可
    8.定义connect函数用于建立父子之间的关系
    同上
    9.定义dfs函数用于输出最后的答案
    根据BST的性质,我们对平衡树进行一次中序遍历即可输出最终的序列

    Code

      1 #include <bits/stdc++.h>
      2 using namespace std;
      3 const int INF = 2147483647;
      4 inline int read() {
      5     int ret = 0, op = 1;
      6     char c = getchar();
      7     while (!isdigit(c)) {
      8         if (c == '-') op = -1; 
      9         c = getchar();
     10     }
     11     while (isdigit(c)) {
     12         ret = ret * 10 + c - '0';
     13         c = getchar();
     14     }
     15     return ret * op;
     16 }
     17 int n, m, in[100010], root, tot;
     18 struct Splay {
     19     int val, sum, fa, ch[2], tag, cnt;
     20 } a[100010];
     21 void update(int now) {
     22     if (!now) return ;
     23     a[now].sum = a[now].cnt;
     24     if (a[now].ch[0]) a[now].sum += a[a[now].ch[0]].sum;
     25     if (a[now].ch[1]) a[now].sum += a[a[now].ch[1]].sum;
     26 }
     27 void pushdown(int now) {
     28     if (now && a[now].tag) {
     29         a[a[now].ch[0]].tag ^= 1;
     30         a[a[now].ch[1]].tag ^= 1;
     31         swap(a[now].ch[1], a[now].ch[0]);
     32         a[now].tag = 0;
     33     }
     34 }
     35 void connect(int x, int fa, int op) {
     36     a[fa].ch[op] = x;
     37     a[x].fa = fa;
     38 }
     39 void rotate(int x) {
     40     int y = a[x].fa;
     41     int z = a[y].fa;
     42     pushdown(x);
     43     pushdown(y);
     44     int xson = a[y].ch[1] == x ? 1 : 0;
     45     int yson = a[z].ch[1] == y ? 1 : 0;
     46     int B = a[x].ch[xson ^ 1];
     47     connect(B, y, xson); connect(y, x, xson ^ 1); connect(x, z, yson);
     48     update(y), update(x);
     49 }
     50 void splay(int from, int to) {
     51     while (a[from].fa != to) {
     52         int y = a[from].fa;
     53         int z = a[y].fa;
     54         if (z != to) (a[y].ch[0] == from) ^ (a[z].ch[0] == y) ? rotate(from) : rotate(y);
     55         rotate(from); 
     56     }
     57     if (to == 0) root = from;
     58 }
     59 int build(int fa, int l, int r) {
     60     if (l > r) return 0;
     61     int mid = l + r >> 1;
     62     int now = ++tot;
     63     a[now].val = in[mid];
     64     a[now].cnt++;
     65     a[now].fa = fa;
     66     a[now].sum++;
     67     a[now].ch[0] = 0;
     68     a[now].ch[1] = 0;    
     69     a[now].ch[0] = build(now, l, mid - 1);
     70     a[now].ch[1] = build(now, mid + 1, r);
     71     update(now);
     72     return now;
     73 }
     74 int find(int x) {
     75     int now = root;
     76     while (1) {
     77         pushdown(now);
     78         if (x <= a[a[now].ch[0]].sum) now = a[now].ch[0];
     79         else {
     80             x -= a[a[now].ch[0]].sum + 1;
     81             if (!x) return now;
     82             now = a[now].ch[1];
     83         }
     84     }
     85 }
     86 void reverse(int l, int r) {
     87     l--, r++;
     88     l = find(l);
     89     r = find(r);
     90     splay(l, 0);
     91     splay(r, l);
     92     int now = a[root].ch[1];
     93     now = a[now].ch[0];
     94     a[now].tag ^= 1;
     95 }
     96 void dfs(int now) {
     97     pushdown(now);
     98     if (a[now].ch[0]) dfs(a[now].ch[0]);
     99     if (a[now].val != INF && a[now].val != -INF) printf("%d ", a[now].val);
    100     if (a[now].ch[1]) dfs(a[now].ch[1]);    
    101 }
    102 int main() {
    103     n = read(); m = read();
    104     in[1] = -INF; in[n + 2] = INF;
    105     for (register int i = 1; i <= n; ++i) in[i + 1] = i;
    106     root = build(0, 1, n + 2);
    107     for (register int i = 1; i <= m; ++i) {
    108         int x = read() + 1, y = read() + 1;
    109         reverse(x, y);
    110     }
    111     dfs(root);
    112     return 0;
    113 }
    AC Code
  • 相关阅读:
    Codeforces Beta Round #92 (Div. 2 Only) B. Permutations 模拟
    POJ 3281 Dining 最大流 Dinic算法
    POJ 2441 Arrange the BUlls 状压DP
    URAL 1152 Faise Mirrors 状压DP 简单题
    URAL 1039 Anniversary Party 树形DP 水题
    URAL 1018 Binary Apple Tree 树形DP 好题 经典
    pytorch中的forward前向传播机制
    .data()与.detach()的区别
    Argparse模块
    pytorch代码调试工具
  • 原文地址:https://www.cnblogs.com/shl-blog/p/11267488.html
Copyright © 2011-2022 走看看