zoukankan      html  css  js  c++  java
  • splay 伸展树 代码实现

    Splay 概念文章: http://blog.csdn.net/naivebaby/article/details/1357734

    叉姐 数组实现: https://github.com/ftiasch/mithril/blob/master/2012-10-24/I.cpp#L43

    Vani 指针实现: https://github.com/Azure-Vani/acm-icpc/blob/master/spoj/SEQ2.cpp

    hdu 1890 写法: http://blog.csdn.net/fp_hzq/article/details/8087431

    HH splay写法: http://www.notonlysuccess.com/index.php/splay-tree/

    poj 3468 HH写法

    View Code
      1 /*
      2 http://acm.pku.edu.cn/JudgeOnline/problem?id=3468
      3 区间更新,区间求和
      4 */
      5 #include <cstdio>
      // keyTree names the slot "left child of the root's right child". update()
      // and query() splay the two neighbours of a range to the root and to the
      // root's right child; this slot is then the root of the range's subtree.
      #define keyTree (ch[ ch[root][1] ][0])
      const int maxn = 222222;  // length of every per-node array below

      /*
       * Array-backed splay tree over a sequence supporting "add c to every
       * element of [l, r]" and "sum of [l, r]" (POJ 3468).
       *
       * Nodes are indices into the member arrays; index 0 is the null node and
       * its sz/sum stay 0. One boundary node holding -1 sits before the first
       * element and one after the last, so the element at 1-based position p
       * has exactly p nodes before it in in-order.
       *
       * Lazy-tag convention: sum[x] already contains the pending add[x] for the
       * whole subtree of x; val[x] and the descendants receive it in push_down.
       */
      struct SplayTree {
          int sz[maxn];              // subtree size
          int ch[maxn][2];           // ch[x][0] = left child, ch[x][1] = right child
          int pre[maxn];             // parent index, 0 for the root
          int root, top1, top2;      // root index, highest index handed out, recycle-stack height
          int ss[maxn], que[maxn];   // recycle stack of freed indices, BFS queue for erase()

          // Rotates x above its parent. f is the side of x on which the old
          // parent ends up: 1 when x is a left child, 0 when x is a right child.
          inline void Rotate(int x, int f) {
              const int y = pre[x];
              push_down(y);
              push_down(x);
              const int moved = ch[x][f];
              ch[y][!f] = moved;
              if (moved) pre[moved] = y;  // never write into the null node
              const int g = pre[y];
              pre[x] = g;
              if (g) ch[g][ ch[g][1] == y ] = x;
              ch[x][f] = y;
              pre[y] = x;
              push_up(y);                 // x itself is refreshed at the end of Splay()
          }

          // Rotates x upwards until its parent is goal; goal == 0 makes x the root.
          inline void Splay(int x, int goal) {
              push_down(x);
              while (pre[x] != goal) {
                  if (pre[pre[x]] == goal) {
                      Rotate(x, ch[pre[x]][0] == x);
                  } else {
                      const int y = pre[x], z = pre[y];
                      const int f = (ch[z][0] == y);
                      if (ch[y][f] == x) {
                          // x and y hang on opposite sides: rotate x twice
                          Rotate(x, !f), Rotate(x, f);
                      } else {
                          // x and y hang on the same side: rotate y first, then x
                          Rotate(y, f), Rotate(x, f);
                      }
                  }
              }
              push_up(x);
              if (goal == 0) root = x;
          }

          // Finds the node that has exactly k nodes before it in in-order and
          // splays it directly below goal. k is not range-checked.
          inline void RotateTo(int k, int goal) {
              int x = root;
              push_down(x);
              while (sz[ ch[x][0] ] != k) {
                  if (k < sz[ ch[x][0] ]) {
                      x = ch[x][0];
                  } else {
                      k -= (sz[ ch[x][0] ] + 1);
                      x = ch[x][1];
                  }
                  push_down(x);
              }
              Splay(x, goal);
          }

          // Detaches the subtree rooted at x from its parent and pushes every
          // node of it onto the recycle stack so NewNode() can reuse the index.
          inline void erase(int x) {
              const int father = pre[x];
              int head = 0, tail = 0;
              for (que[tail++] = x; head < tail; head++) {
                  ss[top2++] = que[head];
                  if (ch[ que[head] ][0]) que[tail++] = ch[ que[head] ][0];
                  if (ch[ que[head] ][1]) que[tail++] = ch[ que[head] ][1];
              }
              ch[father][ ch[father][1] == x ] = 0;
              push_up(father);  // was the undeclared name "pushup"
          }

          // ---- debugging helpers ------------------------------------------------
          // Prints the root index followed by an in-order dump of every node.
          void debug() { printf("%d\n", root); Treaval(root); }
          void Treaval(int x) {
              if (x) {
                  Treaval(ch[x][0]);
                  printf("结点%2d:左儿子 %2d 右儿子 %2d 父结点 %2d size = %2d ,val = %2d\n",x,ch[x][0],ch[x][1],pre[x],sz[x],val[x]);
                  Treaval(ch[x][1]);
              }
          }

          // ---- problem-specific part --------------------------------------------
          // Stores a fresh single-node tree with value c in x, reusing a
          // recycled index when one is available.
          inline void NewNode(int &x, int c) {
              if (top2) x = ss[--top2];
              else x = ++top1;
              ch[x][0] = ch[x][1] = pre[x] = 0;
              sz[x] = 1;

              val[x] = sum[x] = c;
              add[x] = 0;
          }

          // Applies the pending add of x to x's own value and hands it to both
          // children (tag and sum). The null node is left untouched.
          inline void push_down(int x) {
              if (add[x]) {
                  val[x] += add[x];
                  for (int side = 0; side < 2; side++) {
                      const int c = ch[x][side];
                      if (c) {
                          add[c] += add[x];
                          sum[c] += (long long)sz[c] * add[x];
                      }
                  }
                  add[x] = 0;
              }
          }

          // Recomputes sz[x] and sum[x] from the children. A still-pending
          // add[x] covers all sz[x] nodes of the subtree, so it is multiplied
          // by sz[x] (the original added it only once).
          inline void push_up(int x) {
              sz[x] = 1 + sz[ ch[x][0] ] + sz[ ch[x][1] ];
              sum[x] = (long long)add[x] * sz[x] + val[x] + sum[ ch[x][0] ] + sum[ ch[x][1] ];
          }

          // Builds a balanced subtree from num[l..r], stores its root in x and
          // links it to parent f.
          inline void makeTree(int &x, int l, int r, int f) {
              if (l > r) return;
              const int m = (l + r) >> 1;
              NewNode(x, num[m]);
              makeTree(ch[x][0], l, m - 1, x);
              makeTree(ch[x][1], m + 1, r, x);
              pre[x] = f;
              push_up(x);
          }

          // Resets the structure and builds the tree from num[0..n-1], framed
          // by the two -1 boundary nodes.
          inline void build(int n) {
              ch[0][0] = ch[0][1] = pre[0] = sz[0] = 0;
              add[0] = sum[0] = 0;

              root = top1 = top2 = 0;  // top2 reset so a rebuild never reuses stale pool entries
              NewNode(root, -1);
              NewNode(ch[root][1], -1);
              pre[top1] = root;
              sz[root] = 2;

              makeTree(keyTree, 0, n - 1, ch[root][1]);
              push_up(ch[root][1]);
              push_up(root);
          }

          // Reads n integers from stdin into num[] and builds the tree.
          inline void init(int n) {
              for (int i = 0; i < n; i++) scanf("%d", &num[i]);
              build(n);
          }

          // Adds c to every element of the 1-based inclusive range [l, r].
          inline void rangeAdd(int l, int r, int c) {
              RotateTo(l - 1, 0);
              RotateTo(r + 1, root);
              add[ keyTree ] += c;
              sum[ keyTree ] += (long long)c * sz[ keyTree ];
          }

          // Returns the sum of the 1-based inclusive range [l, r].
          inline long long rangeSum(int l, int r) {
              RotateTo(l - 1, 0);
              RotateTo(r + 1, root);
              return sum[ keyTree ];
          }

          // Reads "l r c" from stdin and applies the range add.
          inline void update() {
              int l, r, c;
              scanf("%d%d%d", &l, &r, &c);
              rangeAdd(l, r, c);
          }

          // Reads "l r" from stdin and prints the range sum.
          inline void query() {
              int l, r;
              scanf("%d%d", &l, &r);
              printf("%lld\n", rangeSum(l, r));
          }

          // Problem-specific per-node data.
          int num[maxn];        // input values consumed by makeTree()
          int val[maxn];        // value stored in the node
          int add[maxn];        // pending addition for the node's subtree
          long long sum[maxn];  // subtree sum, pending add[x] included
      } spt;
    163  
    164  
    165 int main() {
    166     int n , m;
    167     scanf("%d%d",&n,&m);
    168     spt.init(n);
    169     while(m --) {
    170         char op[2];
    171         scanf("%s",op);
    172         if(op[0] == 'Q') {
    173             spt.query();
    174         } else {
    175             spt.update();
    176         }
    177     }
    178     return 0;
    179 }

    叉姐 

    View Code
      1 #include <cstdio>
      2 #include <cstring>
      3 #include <vector>
      4 #include <climits>
      5 #include <algorithm>
      6 using namespace std;
      7 
      const int N = 200000;        // length of the per-vertex arrays declared further down
      const int M = 2 * N + 1;     // number of splay-node slots
      const int EMPTY = M - 1;     // last slot, used as the "no node" index

      const int MOD = 99990001;    // modify() reduces weights and tags modulo this value

      // Splay-tree structure, indexed by node.
      int nodeCount;               // next unused node index
      int type[M];                 // 0 / 1: which child of its parent the node is, 2: root
      int parent[M];
      int children[M][2];
      int id[M];                   // label given to newNode()

      // Per-node payload.
      int scale[M];                // pending multiplier for the children
      int delta[M];                // pending addend for the children
      int weight[M];
      int size[M];                 // subtree size, maintained by update()
      int minimum[M];              // smallest id in the subtree, maintained by update()
     17 
     18 void update(int x) {
     19     size[x] = size[children[x][0]] + 1 + size[children[x][1]];
     20     minimum[x] = min(min(minimum[children[x][0]], minimum[children[x][1]]), id[x]);
     21 }
     22 
     23 void modify(int x, int k, int b) {
     24     weight[x] = ((long long)k * weight[x] + b) % MOD;
     25     scale[x] = (long long)k * scale[x] % MOD;
     26     delta[x] = ((long long)k * delta[x] + b) % MOD;
     27 }
     28 
     29 void pushDown(int x) {
     30     for (int i = 0; i < 2; ++ i) {
     31         if (children[x][i] != EMPTY) {
     32             modify(children[x][i], scale[x], delta[x]);
     33         }
     34     }
     35     scale[x] = 1;
     36     delta[x] = 0;
     37 }
     38 
     39 void rotate(int x) {
     40     int t = type[x];
     41     int y = parent[x];
     42     int z = children[x][1 ^ t];
     43     type[x] = type[y];
     44     parent[x] = parent[y];
     45     if (type[x] != 2) {
     46         children[parent[x]][type[x]] = x;
     47     }
     48     type[y] = 1 ^ t;
     49     parent[y] = x;
     50     children[x][1 ^ t] = y;
     51     if (z != EMPTY) {
     52         type[z] = t;
     53         parent[z] = y;
     54     }
     55     children[y][t] = z;
     56     update(y);
     57 }
     58 
     59 void splay(int x) {
     60     if (x == EMPTY) {
     61         return;
     62     }
     63     vector <int> stack(1, x);
     64     for (int i = x; type[i] != 2; i = parent[i]) {
     65         stack.push_back(parent[i]);
     66     }
     67     while (!stack.empty()) {
     68         pushDown(stack.back());
     69         stack.pop_back();
     70     }
     71     while (type[x] != 2) {
     72         int y = parent[x];
     73         if (type[x] == type[y]) {
     74             rotate(y);
     75         } else {
     76             rotate(x);
     77         }
     78         if (type[x] == 2) {
     79             break;
     80         }
     81         rotate(x);
     82     }
     83     update(x);
     84 }
     85 
     86 int goLeft(int x) {
     87     while (children[x][0] != EMPTY) {
     88         x = children[x][0];
     89     }
     90     return x;
     91 }
     92 
     93 int join(int x, int y) {
     94     if (x == EMPTY || y == EMPTY) {
     95         return x != EMPTY ? x : y;
     96     }
     97     y = goLeft(y);
     98     splay(y);
     99     splay(x);
    100     type[x] = 0;
    101     parent[x] = y;
    102     children[y][0] = x;
    103     update(y);
    104     return y;
    105 }
    106 
    107 pair <int, int> split(int x) {
    108     splay(x);
    109     int a = children[x][0];
    110     int b = children[x][1];
    111     children[x][0] = children[x][1] = EMPTY;
    112     if (a != EMPTY) {
    113         type[a] = 2;
    114         parent[a] = EMPTY;
    115     }
    116     if (b != EMPTY) {
    117         type[b] = 2;
    118         parent[b] = EMPTY;
    119     }
    120     return make_pair(a, b);
    121 }
    122 
    123 int newNode(int init, int vid) {
    124     int x = nodeCount ++;
    125     type[x] = 2;
    126     parent[x] = children[x][0] = children[x][1] = EMPTY;
    127     id[x] = vid;
    128     weight[x] = init;
    129     scale[x] = 1;
    130     delta[x] = 0;
    131     update(x);
    132     return x;
    133 }
    134 
    135 int n;
    136 int edgeCount, firstEdge[N], to[M], nextEdge[M], initWeight[N], position[M];
    137 
    138 int root;
    139 
    140 void addEdge(int u, int v) {
    141     to[edgeCount] = v;
    142     nextEdge[edgeCount] = firstEdge[u];
    143     firstEdge[u] = edgeCount ++;
    144 }
    145 
    146 void dfs(int p, int u) {
    147     for (int iter = firstEdge[u]; iter != -1; iter = nextEdge[iter]) {
    148         int v = to[iter];
    149         if (v != p) {
    150             position[iter] = nodeCount;
    151             root = join(root, newNode(initWeight[iter >> 1], min(u, v)));
    152             dfs(u, v);
    153             position[iter ^ 1] = nodeCount;
    154             root = join(root, newNode(initWeight[iter >> 1], min(u, v)));
    155         }
    156     }
    157 }
    158 
    159 int getRank(int x) { // 1-based
    160     splay(x);
    161     return size[children[x][0]] + 1;
    162 }
    163 
    164 void print(int root) {
    165     if (root != EMPTY) {
    166         printf("[ ");
    167         print(children[root][0]);
    168         printf(" %d ", root);
    169         print(children[root][1]);
    170         printf(" ]");
    171     }
    172 }
    173 
    174 int main() {
    175     size[EMPTY] = 0;
    176     minimum[EMPTY] = INT_MAX;
    177     parent[EMPTY] = 2;
    178     scanf("%d", &n);
    179     edgeCount = 0;
    180     memset(firstEdge, -1, sizeof(firstEdge));
    181     for (int i = 0; i < n - 1; ++ i) {
    182         int a, b;
    183         scanf("%d%d%d", &a, &b, initWeight + i);
    184         a --;
    185         b --;
    186         addEdge(a, b);
    187         addEdge(b, a);
    188     }
    189     nodeCount = 0;
    190     root = EMPTY;
    191     dfs(-1, 0);
    192     for (int i = 0; i < n - 1; ++ i) {
    193         int id;
    194         scanf("%d", &id);
    195         id --;
    196 
    197         int a = position[id << 1];
    198         int b = position[(id << 1) ^ 1];
    199         if (getRank(a) > getRank(b)) {
    200             swap(a, b);
    201         }
    202         splay(a);
    203 
    204         int output = weight[a];
    205         printf("%d\n", output);
    206         fflush(stdout);
    207 
    208         pair <int, int> ret1 = split(a);
    209         pair <int, int> ret2 = split(b);
    210         int x = ret1.first;
    211         int y = ret2.first;
    212         int z = ret2.second;
    213         x = join(z, x);
    214         splay(x);
    215         splay(y);
    216         if (size[x] > size[y]) {
    217             swap(x, y);
    218         }
    219         if (size[x] == size[y] && minimum[x] > minimum[y]) {
    220             swap(x, y);
    221         }
    222         modify(x, output, 0);
    223         modify(y, 1, output);
    224     }
    225     return 0;
    226 }

    spoj SEQ2

    Vani 

    View Code
      1 #include <cstdio>
      2 #include <cctype>
      3 #include <algorithm>
      4 #include <cstring>
      5 
      6 using namespace std;
      7 
      // Pointer-based splay tree for SPOJ SEQ2: a sequence with INSERT, DELETE,
      // MAKE-SAME, REVERSE, GET-SUM and MAX-SUM operations.
      namespace Solve {
          const int MAXN = 500010;
          const int inf = 500000000;

          // Whole input, read once by solve(); pos is the read cursor. solve()
          // leaves at least one trailing '\0', which the scanners treat as end of input.
          char BUF[50000000], *pos = BUF;

          // Parses the next (optionally negative) decimal integer from BUF.
          // Returns 0 without advancing past the terminator when the input is exhausted.
          inline int ScanInt(void) {
              int r = 0, d = 0;
              // <cctype> functions need an unsigned char value, hence the casts.
              while (*pos && !isdigit(static_cast<unsigned char>(*pos)) && *pos != '-') pos++;
              if (!*pos) return 0;
              if (*pos != '-') r = *pos - 48; else d = 1;
              pos++;
              while (isdigit(static_cast<unsigned char>(*pos))) r = r * 10 + *pos++ - 48;
              return d ? -r : r;
          }

          // Copies the next run of upper-case letters and '-' into st and
          // NUL-terminates it. st becomes the empty string when the input is
          // exhausted. The caller must supply a large enough buffer.
          inline void ScanStr(char *st) {
              int l = 0;
              while (*pos && !(isupper(static_cast<unsigned char>(*pos)) || *pos == '-')) pos++;
              while (isupper(static_cast<unsigned char>(*pos)) || *pos == '-') st[l++] = *pos++;
              st[l] = 0;
          }

          // One element of the sequence plus the aggregates of its subtree.
          struct Node {
              Node *ch[2], *p;  // children and parent; the shared null node stands for "none"
              // v: element value; sum: subtree sum; size: subtree node count;
              // lmax / rmax: best prefix / suffix sum (the empty one counts as 0);
              // m: best sum of a run containing at least one element;
              // same: pending MAKE-SAME value, -inf when none; rev: pending reversal flag.
              int v, lmax, rmax, m, same, rev, sum, size;

              inline bool dir(void) {return this == p->ch[1];}
              inline void SetC(Node *x, bool d) {ch[d] = x, x->p = this;}

              // Recomputes every aggregate of this node from its children.
              inline void Update(void) {
                  Node *L = ch[0], *R = ch[1];
                  size = L->size + R->size + 1;
                  m = std::max(L->m, R->m);
                  m = std::max(m, L->rmax + v + R->lmax);
                  lmax = std::max(L->lmax, L->sum + v + R->lmax);
                  rmax = std::max(R->rmax, R->sum + v + L->rmax);
                  sum = L->sum + R->sum + v;
              }

              // Reverses the subtree lazily; a no-op on the null node (v == -inf).
              inline void Rev(void) {
                  if (v == -inf) return;
                  rev ^= 1;
                  std::swap(ch[0], ch[1]);
                  std::swap(lmax, rmax);
              }

              // Sets every element of the subtree to u lazily; a no-op on the null node.
              inline void Same(int u) {
                  if (v == -inf) return;
                  same = u;
                  sum = u * size;
                  if (sum > 0) lmax = rmax = m = sum; else lmax = 0, rmax = 0, m = u;
                  v = u;
              }

              // Pushes the pending reversal and assignment down to the children.
              inline void Down(void) {
                  if (rev) {
                      ch[0]->Rev(), ch[1]->Rev();
                      rev = 0;
                  }
                  if (same != -inf) {
                      ch[0]->Same(same), ch[1]->Same(same);
                      same = -inf;
                  }
              }
          } Tnull, *null = &Tnull;

          // The tree itself. Build() receives one extra element at each end of
          // the array, so element i of the sequence is the (i + 1)-th node in order.
          // NOTE: nodes are allocated with new and never freed, including those
          // cut off by Delete().
          class Splay {public:
              Node *root;

              // Rotates x above its parent.
              inline void rotate(Node *x) {
                  Node *p = x->p; bool d = x->dir();
                  p->Down(); x->Down();
                  p->p->SetC(x, p->dir());
                  p->SetC(x->ch[!d], d);
                  x->SetC(p, !d);
                  p->Update();
              }

              // Rotates x up until its parent is G; G == null makes x the root.
              inline void splay(Node *x, Node *G) {
                  if (G == null) root = x;
                  while (x->p != G) {
                      if (x->p->p == G) {rotate(x); break;}
                      if (x->dir() == x->p->dir()) rotate(x->p), rotate(x);
                      else rotate(x), rotate(x);
                  }
                  x->Update();
              }

              // Splays the k-th node (1-based, in order) to the root and returns it.
              inline Node *Select(int k) {
                  Node *t = root;
                  while (t->Down(), t->ch[0]->size + 1 != k) {
                      if (k > t->ch[0]->size + 1) k -= t->ch[0]->size + 1, t = t->ch[1];
                      else t = t->ch[0];
                  }
                  splay(t, null);
                  return t;
              }

              // Isolates elements l..r: returns the node whose left subtree holds
              // exactly that range (empty when r == l - 1).
              inline Node *getInterval(int l, int r) {
                  Node *L = Select(l), *R = Select(r + 2);
                  splay(L, null); splay(R, L);
                  L->Down(); R->Down();
                  return R;
              }

              // Inserts the tree x directly after element pos.
              inline void Insert(int pos, Node *x) {
                  Node *now = getInterval(pos + 1, pos);
                  now->SetC(x, 0);
                  now->Update(); root->Update();
              }

              // Removes elements l..r.
              inline void Delete(int l, int r) {
                  Node *now = getInterval(l, r);
                  now->ch[0] = null;
                  now->Update(); root->Update();
              }

              // Sets elements l..r to c.
              inline void Make(int l, int r, int c) {
                  Node *now = getInterval(l, r);
                  now->ch[0]->Same(c);
                  now->Update(); root->Update();
              }

              // Reverses elements l..r.
              inline void Reverse(int l, int r) {
                  Node *now = getInterval(l, r);
                  now->ch[0]->Rev();
                  now->Update(); root->Update();
              }

              // Returns the sum of elements l..r.
              inline int Sum(int l, int r) {
                  Node *now = getInterval(l, r);
                  root->Down(); now->Down();
                  return now->ch[0]->sum;
              }

              // Returns the best sum of a non-empty contiguous run inside l..r.
              inline int maxSum(int l, int r) {
                  Node *now = getInterval(l, r);
                  root->Down(); now->Down();
                  return now->ch[0]->m;
              }

              // Creates a detached single node holding c.
              inline Node* Renew(int c) {
                  // Value-initialise: Same() reads v and Down() later reads rev,
                  // both of which a plain "new Node" leaves indeterminate.
                  Node *ret = new Node();
                  ret->ch[0] = ret->ch[1] = ret->p = null; ret->size = 1;
                  ret->Same(c); ret->same = -inf;
                  return ret;
              }

              // Builds a balanced tree from a[l..r] and returns its root (null when l > r).
              inline Node* Build(int l, int r, int *a) {
                  if (l > r) return null;
                  int mid = (l + r) >> 1;
                  Node *ret = Renew(a[mid]);
                  ret->ch[0] = Build(l, mid - 1, a);
                  ret->ch[1] = Build(mid + 1, r, a);
                  ret->ch[0]->p = ret->ch[1]->p = ret;
                  ret->Update();
                  return ret;
              }

              // Debug helper: prints the values below t in order.
              inline void P(Node *t) {
                  if (t == null) return;
                  t->Down(); t->Update();
                  P(t->ch[0]);
                  printf("%d ", t->v);
                  P(t->ch[1]);
              }
          }T;


          int a[MAXN]; char ch[10];

          // Reads the whole input and processes every test case: n, m, the n
          // initial values, then m commands. Lengths given by the commands are
          // converted to inclusive [l, r] ranges before calling the tree.
          inline void solve(void) {
              // Read one byte less than the buffer so a '\0' terminator remains.
              fread(BUF, 1, sizeof(BUF) - 1, stdin);
              null->same = null->m = null->v = -inf;
              int kase = ScanInt();
              while (kase--) {
                  int n = ScanInt(), m = ScanInt();
                  for (int i = 1; i <= n; i++) a[i] = ScanInt();
                  T.root = T.Build(0, n + 1, a);  // a[0] and a[n + 1] become the boundary nodes
                  for (int i = 1; i <= m; i++) {
                      ScanStr(ch);
                      if (strcmp(ch, "INSERT") == 0) {
                          int pos = ScanInt(), t = ScanInt();
                          for (int j = 1; j <= t; j++) a[j] = ScanInt();
                          Node *tmp = T.Build(1, t, a);
                          T.Insert(pos, tmp);
                      } else if (strcmp(ch, "DELETE") == 0) {
                          int l = ScanInt(), r = ScanInt(); r = l + r - 1;
                          T.Delete(l, r);
                      } else if (strcmp(ch, "MAKE-SAME") == 0) {
                          int l = ScanInt(), r = ScanInt(), c = ScanInt(); r = l + r - 1;
                          T.Make(l, r, c);
                      } else if (strcmp(ch, "REVERSE") == 0) {
                          int l = ScanInt(), r = ScanInt(); r = l + r - 1;
                          T.Reverse(l, r);
                      } else if (strcmp(ch, "GET-SUM") == 0) {
                          int l = ScanInt(), r = ScanInt(); r = l + r - 1;
                          int ret = T.Sum(l, r);
                          printf("%d\n", ret);
                      } else if (strcmp(ch, "MAX-SUM") == 0) {
                          int ret = T.maxSum(1, T.root->size - 2);
                          printf("%d\n", ret);
                      }
                  }
              }
          }
      }
    198 
    199 int main(void) {
    200     freopen("in", "r", stdin);
    201     Solve::solve();
    202     return 0;
    203 }
  • 相关阅读:
    错误解决记录-------------验证启动HDFS时遇到的错误
    Spark环境搭建(一)-----------HDFS分布式文件系统搭建
    Synergy简单使用小记
    python基础一 ------排序和查找算法
    Scrapy基础(十四)————Scrapy实现知乎模拟登陆
    Scrapy基础(十四)————知乎模拟登陆
    Scrapy基础(十三)————ItemLoader的简单使用
    Scrapy基础(十二)————异步导出Item数据到Mysql中
    简单python爬虫练习 E站本爬取
    7-4 jmu-Java&Python-统计文字中的单词数量并按出现次数排序 (25分)
  • 原文地址:https://www.cnblogs.com/yefeng1627/p/3006308.html
Copyright © 2011-2022 走看看