zoukankan      html  css  js  c++  java
  • 模板——Splay

    Splay

    #include <bits/stdc++.h>
    #define inf (int)1e9
    using namespace std;
    const int N=100000+100;   // capacity of the static node pool
    int n;                    // number of operations read by main
    int tot;                  // next unused node index, handed out by newnode()
    int root;                 // index of the current root (0 = empty tree)
    int val[N];               // key stored in each node
    int sz[N];                // total multiplicity stored in the node's subtree
    int son[N][2];            // son[x][0] = left child, son[x][1] = right child
    int fa[N];                // parent index (0 for the root)
    int sf[N];                // side of its parent the node hangs on (0 left, 1 right)
    int re[N];                // multiplicity of the node's key
    // Hand out the next never-used slot of the static node pool.
    int newnode()
    {
        int id=tot;
        ++tot;
        return id;
    }
    // Attach x as child `dir` (0 = left, 1 = right) of y, recording both the
    // child pointer on y and the parent/side bookkeeping on x.
    void connect(int x,int y,int dir)
    {
        fa[x]=y;
        sf[x]=dir;
        son[y][dir]=x;
    }
    // Recompute x's subtree size: its own multiplicity plus both children's sizes.
    void pushup(int x)
    {
        const int *kid=son[x];
        sz[x]=re[x]+sz[kid[0]]+sz[kid[1]];
    }
    // Reset every field of node 0, the null node. rotate() calls this because
    // connect() may have written through index 0 (e.g. attaching an empty child).
    void clear()
    {
        val[0]=0;
        son[0][0]=son[0][1]=0;
        sz[0]=re[0]=0;
        fa[0]=sf[0]=0;
    }
    void rotate(int x)
    {
        int f,gf,xd,fd,s;
        f=fa[x];gf=fa[f];
        xd=sf[x];fd=sf[f];
        s=son[x][xd^1];
        connect(x,gf,fd);connect(f,x,xd^1);connect(s,f,xd);
        clear();
        pushup(f);pushup(x);
    }
    // Splay x upward until its parent is y; y == 0 makes x the new root.
    // Zig-zig rotates the parent first; zig-zag rotates x twice. The loop
    // increment supplies the final rotation of x in every case.
    void splay(int x,int y)
    {
        for (;fa[x]!=y;rotate(x))
        {
            const int p=fa[x];
            if (fa[p]!=y)
              rotate(sf[x]==sf[p]?p:x);
        }
        if (!y)
          root=x;
    }
    // Walk down from the root toward key v and splay the last node reached to
    // the root. The search stops when it hits v or when the next child is absent.
    void find(int v)
    {
        int cur=root;
        for (;;)
        {
            if (val[cur]==v)
              break;
            const int next=son[cur][v>val[cur]];
            if (!next)
              break;
            cur=next;
        }
        splay(cur,0);
    }
    // Node with the largest key strictly less than v.
    // After find(v), if the root is already below v it is the answer;
    // otherwise the answer is the rightmost node of the root's left subtree.
    int per(int v)
    {
        find(v);
        int cur=root;
        if (val[cur]>=v)
        {
            cur=son[cur][0];
            while (son[cur][1])
              cur=son[cur][1];
        }
        return cur;
    }
    // Node with the smallest key strictly greater than v.
    // After find(v), if the root is already above v it is the answer;
    // otherwise the answer is the leftmost node of the root's right subtree.
    int suc(int v)
    {
        find(v);
        int cur=root;
        if (val[cur]<=v)
        {
            cur=son[cur][1];
            while (son[cur][0])
              cur=son[cur][0];
        }
        return cur;
    }
    // Insert one copy of v.
    // - An existing key just has its multiplicity bumped.
    // - Otherwise a new leaf is hung below the node where the search stopped.
    // The touched node is then splayed to the root.
    void insert(int v)
    {
        if (root==0)
        {
            // First node of an empty tree. Its multiplicity must be 1 as well;
            // with re left at 0 the key would vanish from subtree sizes as soon
            // as pushup() recomputes this node.
            int x=newnode();
            val[x]=v;re[x]=sz[x]=1;
            root=x;
            return;
        }
        int cur=root;
        while (son[cur][v>val[cur]] && val[cur]!=v)
          cur=son[cur][v>val[cur]];
        if (val[cur]==v)
        {
            re[cur]++;sz[cur]++;
            splay(cur,0);   // each rotation re-runs pushup on the ancestors it passes
            return;
        }
        int x=newnode();
        val[x]=v;re[x]=sz[x]=1;
        connect(x,cur,v>val[cur]);
        pushup(cur);
        splay(x,0);
    }
    // Remove one copy of v.
    // The predecessor p is splayed to the root and the successor s directly
    // below it. That leaves the node holding v (if any) as s's left child,
    // and it is necessarily a leaf.
    // Assumes keys below and above v exist (main inserts -inf/+inf sentinels).
    // Does nothing when v is not present.
    void del(int v)
    {
        int p=per(v);
        int s=suc(v);
        splay(p,0);
        splay(s,p);
        int x=son[s][0];
        if (!x)
          return;   // v absent: touching re[0]/sz[0] would corrupt the null node
        re[x]--;sz[x]--;
        if (re[x]==0)
          son[s][0]=0;
        pushup(s);pushup(p);
    }
    int rk(int v)
    {
        find(v);
        return sz[son[root][0]];
    }
    int kth(int x,int k)
    {
        if (k>sz[son[x][0]]+re[x])
          return kth(son[x][1],k-sz[son[x][0]]-re[x]);
        if (k<=sz[son[x][0]])
          return kth(son[x][0],k);
        return x;
    }
    int main()
    {
        tot=1;
        insert(inf);insert(-inf);
        scanf("%d",&n);
        for (int i=1;i<=n;i++)
        {
            int op,x;
            scanf("%d%d",&op,&x);
            if (op==1) insert(x);
            if (op==2) del(x);
            if (op==3) printf("%d
    ",rk(x));
            if (op==4) printf("%d
    ",val[kth(root,x+1)]);
            if (op==5) printf("%d
    ",val[per(x)]);
            if (op==6) printf("%d
    ",val[suc(x)]);
        }
    }

    维护数列

    #include <bits/stdc++.h>
    #define inf (int)1e9
    using namespace std;
    const int N=500000+100;   // node-pool size before deleted nodes get recycled
    int n;                    // initial sequence length
    int m;                    // number of operations
    int tot;                  // next never-used node index (see newnode)
    int root;                 // index of the current splay-tree root
    int a[N];                 // scratch buffer of values handed to build()
    // One splay-tree node. Fields are always accessed by name.
    struct node
    {
        int sz;       // number of nodes in this subtree
        int val;      // value stored at this node
        int sum;      // sum of the values in this subtree
        int res;      // reverse tag: this node's children still need flipping
        int lx;       // best prefix sum of the subtree, clamped at 0
        int rx;       // best suffix sum of the subtree, clamped at 0
        int tx;       // best non-empty contiguous sum inside the subtree
        int vc;       // make-same value still to be pushed to the children; inf = none
        int son[2];   // son[0] left child, son[1] right child
        int fa;       // parent index
        int sf;       // side of the parent this node hangs on
    };
    node sh[N+100];          // node pool; index 0 is the null node
    std::queue<int> q;       // indices freed by clear(), reused by newnode()
    int newnode()
    {
        int x;
        if (tot>=N)
        {
            x=q.front();
            q.pop();
        }
        else
          x=tot++;
        sh[x].sz=sh[x].val=sh[x].sum=sh[x].res=sh[x].lx=sh[x].rx=sh[x].tx=sh[x].vc=0;
        sh[x].son[0]=sh[x].son[1]=sh[x].sf=sh[x].fa=0;
        sh[x].vc=inf;
        return x;
    }
    // Push every node of the subtree rooted at x into the recycle queue q,
    // in pre-order.
    // - An explicit stack replaces recursion, so a deep (degenerate) subtree
    //   cannot overflow the call stack.
    // - x == 0 (an empty subtree) is ignored, so the null node is never
    //   handed out again by newnode().
    void clear(int x)
    {
        if (!x)
          return;
        std::vector<int> stk(1,x);
        while (!stk.empty())
        {
            const int cur=stk.back();
            stk.pop_back();
            q.push(cur);
            // right first so the left subtree is visited first (pre-order)
            if (sh[cur].son[1]) stk.push_back(sh[cur].son[1]);
            if (sh[cur].son[0]) stk.push_back(sh[cur].son[0]);
        }
    }
    void connect(int x,int y,int dir)//x->y
    {
        sh[y].son[dir]=x;
        sh[x].fa=y;
        sh[x].sf=dir;
    }
    void pushdown(int x)
    {
        int ls,rs;
        ls=sh[x].son[0];rs=sh[x].son[1];
        if (sh[x].res==1)
        {
            if (ls)
            {
                sh[sh[ls].son[0]].sf^=1;sh[sh[ls].son[1]].sf^=1;
                swap(sh[ls].son[0],sh[ls].son[1]);
                swap(sh[sh[ls].son[0]].lx,sh[sh[ls].son[0]].rx);
                swap(sh[sh[ls].son[1]].lx,sh[sh[ls].son[1]].rx);
                sh[ls].res^=1;
            }
            if (rs)
            {
                sh[sh[rs].son[0]].sf^=1;sh[sh[rs].son[1]].sf^=1;
                swap(sh[rs].son[0],sh[rs].son[1]);
                swap(sh[sh[rs].son[0]].lx,sh[sh[rs].son[0]].rx);
                swap(sh[sh[rs].son[1]].lx,sh[sh[rs].son[1]].rx);
                sh[rs].res^=1;
            }
            sh[x].res=0;
        }
        if (sh[x].vc!=inf)
        {
            if (ls)
            {
                sh[ls].val=sh[ls].vc=sh[x].vc;
                sh[ls].sum=sh[ls].val*sh[ls].sz;
                sh[ls].tx=max(sh[ls].val,sh[ls].val*sh[ls].sz);
                sh[ls].lx=sh[ls].rx=max(0,sh[ls].val*sh[ls].sz);
            }
            if (rs)
            {
                sh[rs].val=sh[rs].vc=sh[x].vc;
                sh[rs].sum=sh[rs].val*sh[rs].sz;
                sh[rs].tx=max(sh[rs].val,sh[rs].val*sh[rs].sz);
                sh[rs].lx=sh[rs].rx=max(0,sh[rs].val*sh[rs].sz);
            }
            sh[x].vc=inf;
        }
    }
    void pushup(int x)
    {
        int ls,rs;
        ls=sh[x].son[0];rs=sh[x].son[1];
        sh[x].sz=sh[ls].sz+sh[rs].sz+1;
        sh[x].sum=sh[ls].sum+sh[rs].sum+sh[x].val;
        sh[x].lx=max(sh[ls].lx,sh[ls].sum+sh[rs].lx+sh[x].val);
        sh[x].rx=max(sh[rs].rx,sh[rs].sum+sh[ls].rx+sh[x].val);
        sh[x].tx=max(sh[ls].tx,max(sh[rs].tx,sh[ls].rx+sh[rs].lx+sh[x].val));
    }
    void rotate(int x)
    {
        int fa,gf,son,xd,fd;
        fa=sh[x].fa;gf=sh[sh[x].fa].fa;
        xd=sh[x].sf;fd=sh[fa].sf;
        son=sh[x].son[xd^1];
        connect(x,gf,fd);connect(fa,x,xd^1);connect(son,fa,xd);
        sh[0].fa=sh[0].son[0]=sh[0].son[1]=sh[0].sf=0;
        pushup(fa);pushup(x);
    }
    // Splay x upward until its parent is y; y == 0 makes x the new root.
    // Zig-zig rotates the parent first; zig-zag rotates x. The loop increment
    // performs the closing rotation of x.
    void splay(int x,int y)
    {
        for (;sh[x].fa!=y;rotate(x))
        {
            const int p=sh[x].fa;
            if (sh[p].fa!=y)
              rotate(sh[x].sf==sh[p].sf?p:x);
        }
        if (!y)
          root=x;
    }
    // Return the node at 1-based in-order position k within the subtree
    // rooted at x. Tags are pushed down along the way so child order is current.
    int find(int x,int k)
    {
        for (;;)
        {
            pushdown(x);
            const int lsz=sh[sh[x].son[0]].sz;
            if (k<=lsz)
              x=sh[x].son[0];
            else if (k==lsz+1)
              return x;
            else
            {
                k-=lsz+1;
                x=sh[x].son[1];
            }
        }
    }
    // Build a balanced subtree over a[l..r], with the middle element at its
    // root, and return that root.
    // - The root's parent link is set to (father, dir); the caller attaches
    //   it on the parent side.
    // - An empty range (l > r) returns 0, the null node, instead of
    //   allocating a node that holds a[mid].
    int build(int l,int r,int father,int dir)
    {
        if (l>r)
          return 0;
        int mid=(l+r)>>1;
        int x=newnode();
        sh[x].val=a[mid];
        sh[x].fa=father;sh[x].sf=dir;
        if (l==r)
        {
            sh[x].sz=1;
            sh[x].sum=sh[x].tx=a[mid];
            sh[x].lx=sh[x].rx=std::max(0,a[mid]);
            return x;
        }
        sh[x].son[0]=build(l,mid-1,x,0);   // 0 when the left half is empty
        sh[x].son[1]=build(mid+1,r,x,1);
        pushup(x);
        return x;
    }
    // Isolate elements l..r (1-based, excluding the leading sentinel).
    // The node at tree rank l becomes the root and the node at rank r+2 its
    // right child. The left child of that node is then exactly the range;
    // its index is returned (0 if the range is empty).
    int split(int l,int r)
    {
        const int lo=find(root,l);
        const int hi=find(root,r+2);
        splay(lo,0);
        splay(hi,lo);
        return sh[hi].son[0];
    }
    int main()
    {
        scanf("%d%d",&n,&m);
        for (int i=1;i<=n;i++)
          scanf("%d",&a[i]);
        tot=1;sh[0].tx=a[0]=a[n+1]=-inf;
        root=build(0,n+1,0,0);
        while (m--)
        {
            char ch[20];
            scanf("%s",ch);
            if (ch[0]=='I')
            {
                int pos,tot;
                scanf("%d%d",&pos,&tot);
                for (int i=1;i<=tot;i++)
                  scanf("%d",&a[i]);
                int x=build(1,tot,0,0);
                int A,B;
                A=find(root,pos+1);B=find(root,pos+2);
                splay(A,0);splay(B,A);
                connect(x,B,0);
                pushup(B);pushup(A);
            }
            if (ch[0]=='D')
            {
                int pos,tot;
                scanf("%d%d",&pos,&tot);
                int per,suc;
                per=find(root,pos);suc=find(root,pos+tot+1);
                splay(per,0);
                splay(suc,per);
                clear(sh[suc].son[0]);
                sh[suc].son[0]=0;
                pushup(suc);pushup(per);
            }
            if (ch[0]=='M' && ch[2]=='K')
            {
                int pos,tot,c;
                scanf("%d%d%d",&pos,&tot,&c);
                int x=split(pos,pos+tot-1);
                sh[x].val=sh[x].vc=c;
                sh[x].sum=c*sh[x].sz;
                sh[x].tx=max(sh[x].val,c*sh[x].sz);
                sh[x].lx=sh[x].rx=max(0,c*sh[x].sz);
                pushup(sh[x].fa);pushup(sh[sh[x].fa].fa);
            }
            if (ch[0]=='R')
            {
                int pos,tot;
                scanf("%d%d",&pos,&tot);
                int x=split(pos,pos+tot-1);
                sh[x].res^=1;
                int ls,rs;
                ls=sh[x].son[0];rs=sh[x].son[1];
                sh[ls].sf^=1;sh[rs].sf^=1;
                swap(sh[x].son[0],sh[x].son[1]);
                swap(sh[ls].lx,sh[ls].rx);
                swap(sh[rs].lx,sh[rs].rx);
                pushup(x);pushup(sh[x].fa);pushup(sh[sh[x].fa].fa);
            }
            if (ch[0]=='G')
            {
                int pos,tot;
                scanf("%d%d",&pos,&tot);
                int x=split(pos,pos+tot-1);
                printf("%d
    ",sh[x].sum);
            }
            if (ch[0]=='M' && ch[2]=='X')
            {
                printf("%d
    ",sh[root].tx);
            }
        }
    }
  • 相关阅读:
    Python学习资料
    异常
    I/O
    Python3+迭代器与生成器
    python标准数据类型
    人工智能、机器学习和深度学习
    原地排序和复制排序
    序列化和Json
    登陆加密小程序
    hashlib模块加密用法
  • 原文地址:https://www.cnblogs.com/huangchenyan/p/11216762.html
Copyright © 2011-2022 走看看