zoukankan      html  css  js  c++  java
  • 二逼平衡树 题解(树套树)

    题面

    我 想 扇 死 自 己

    void up(int x)
        {
            if(x)
            {
                size[x]=cnt[x];//我TM这行忘了
                if(son[x][0])size[x]+=size[son[x][0]];
                if(son[x][1])size[x]+=size[son[x][1]];
            }
        }

    4个小时!调一道模板!我敲里码!

    上道splay刚因为细节打错浪费了3个小时时间,这次就又**重现了

    不多说了,先把splay抄上10遍,手写!

    -----------以下是正经题解----------------

    第一道树套树:线段树套splay

    对于线段树的每一段区间建splay维护这段的信息

    在合并时:

    排名相加;

    前驱取max;

    后继取min;

    比较麻烦的是查询数值,需要二分答案.

    以数值为值域进行二分,不断询问mid的排名来缩小范围。

    #include<cstdio>
    #include<iostream>
    #include<algorithm>
    #include<cstring>
    using namespace std;
    const int N=4000005,inf=1e9;
    int n,m,a[N];
        int root[N],son[N][3],fa[N],key[N],size[N],type,cnt[N];
        void clear(int x)
        {
            if(!x)return ;
            fa[x]=cnt[x]=son[x][0]=son[x][1]=size[x]=key[x]=0;
        }
        int pre(int k)
        {
            int now=son[root[k]][0];
            while(son[now][1])now=son[now][1];
            return now;
        }
        bool judge(int x)
        {
            return son[fa[x]][1]==x;
        }
        void up(int x)
        {
            if(x)
            {
                size[x]=cnt[x];
                if(son[x][0])size[x]+=size[son[x][0]];
                if(son[x][1])size[x]+=size[son[x][1]];
            }
        }
        void rotate(int x)
        {
            int old=fa[x],oldf=fa[old],lr=judge(x);
            son[old][lr]=son[x][lr^1];
            fa[son[old][lr]]=old;
            son[x][lr^1]=old;
            fa[old]=x;
            fa[x]=oldf;
            if(oldf)son[oldf][son[oldf][1]==old]=x;
            up(old);up(x);
        }
        void splay(int k,int x)
        {
            for(int f;f=fa[x];rotate(x))
                if(fa[f])rotate(judge(x)==judge(f)?f:x);
            root[k]=x;
        }
        void ins(int k,int x)
        {
            if(!root[k])
            {
                type++;
                key[type]=x;
                root[k]=type;
                cnt[type]=size[type]=1;
                fa[type]=son[type][0]=son[type][1]=0;
                return ;
            }
            int now=root[k],f=0;
            while(1)
            {
                if(x==key[now])
                {
                    cnt[now]++;
                    up(now);
                    up(f);
                    splay(k,now);
                    return ;
                }
                f=now;now=son[now][key[now]<x];
                if(!now)
                {
                    type++;
                    size[type]=cnt[type]=1;
                    son[type][0]=son[type][1]=0;
                    son[f][x>key[f]]=type;
                    fa[type]=f;
                    key[type]=x;
                    up(f);splay(k,type);
                    return ;
                }
            }
        }
        int getrank(int k,int x)
        {
            int now=root[k],ans=0;
            while(1)
            {
                if(!now)return ans;
                if(x==key[now])return (son[now][0]?size[son[now][0]]:0)+ans;
                else if(x>key[now])
                {
                    ans+=(son[now][0]?size[son[now][0]]:0)+cnt[now];
                    now=son[now][1];
                }
                else if(x<key[now])now=son[now][0];
            }
        }
        int findpos(int k,int x)
        {
            int now=root[k];
            while(1)
            {
                if(x==key[now])return now;
                else if(x<key[now])now=son[now][0];
                else now=son[now][1];
            }
        }
        int findpre(int k,int x)
        {
            int now=root[k],ans=0;
            while(now)
            {
                if(key[now]<x)
                {
                    if(ans<key[now])ans=key[now];
                    now=son[now][1];
                }
                else now=son[now][0];
            }
            return ans;
        }
        int findnxt(int k,int x)
        {
            int now=root[k],ans=inf;
            while(now)
            {
                if(key[now]>x)
                {
                    if(ans>key[now])ans=key[now];
                    now=son[now][0];
                }
                else now=son[now][1];
            }
            return ans;
        }
        void del(int k,int x)
        {
            int now=findpos(k,x);
            splay(k,now);
            if(cnt[root[k]]>1)
            {
                cnt[root[k]]--;
                up(root[k]);
                return ;
            }
            else if(!son[root[k]][0]&&(!son[root[k]][1]))
            {
                clear(root[k]);
                root[k]=0;
                return ;
            }
            int old=root[k];
            if(son[root[k]][0]*son[root[k]][1]==0)
            {
                if(!son[root[k]][0])root[k]=son[root[k]][1];
                else root[k]=son[root[k]][0];
                fa[root[k]]=0;
                clear(old);
                return ;
            }
            int L=pre(k);
            splay(k,L);
            son[root[k]][1]=son[old][1];
            fa[son[old][1]]=root[k];
            clear(old);
            up(root[k]);
        }
        #define ls(k) k<<1
        #define rs(k) k<<1|1
        void update(int k,int l,int r,int pos,int val)
        {
            ins(k,val);
            if(l==r)return ;
            int mid=l+r>>1;
            if(pos<=mid)update(ls(k),l,mid,pos,val);
            else update(rs(k),mid+1,r,pos,val);
            return ;
        }
        int rank(int k,int l,int r,int L,int R,int val)
        {
            if(l>=L&&r<=R)
            {
                int res=getrank(k,val);
                return res;
            }
            int mid=l+r>>1,res=0;
            if(L<=mid)res+=rank(ls(k),l,mid,L,R,val);
            if(R>mid)res+=rank(rs(k),mid+1,r,L,R,val);
            return res;
        }
        void modify(int k,int l,int r,int pos,int val)
        {
            del(k,a[pos]);
            ins(k,val);
            if(l==r)return ;
            int mid=l+r>>1;
            if(pos<=mid)modify(ls(k),l,mid,pos,val);
            else modify(rs(k),mid+1,r,pos,val);
        }
        int getpre(int k,int l,int r,int L,int R,int val)
        {
            if(l>=L&&r<=R)return findpre(k,val);
            int mid=l+r>>1,res=0;
            if(L<=mid)res=max(res,getpre(ls(k),l,mid,L,R,val));
            if(R>mid)res=max(res,getpre(rs(k),mid+1,r,L,R,val));
            return res;
        }
        int getnxt(int k,int l,int r,int L,int R,int val)
        {
            if(l>=L&&r<=R)return findnxt(k,val);
            int mid=l+r>>1,res=inf;
            if(L<=mid)res=min(res,getnxt(ls(k),l,mid,L,R,val));
            if(R>mid)res=min(res,getnxt(rs(k),mid+1,r,L,R,val));
            return res;
        }
    inline int read()
    {
        int x=0,f=1;char ch=getchar();
        while(ch<'0'||ch>'9')
        {if(ch=='-')f=-1;ch=getchar();}
        while(ch>='0'&&ch<='9')
        {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
        return x*f;
    }
    int main()
    {
        n=read();m=read();
        int op,maxx=0;
        for(int i=1;i<=n;i++)
        {
            a[i]=read();
            update(1,1,n,i,a[i]);
            maxx=max(maxx,a[i]);
        }
        while(m--)
        {
            op=read();
            if(op==1)
            {
                int l=read(),r=read(),val=read();
                printf("%d
    ",rank(1,1,n,l,r,val)+1);
            }
            else if(op==2)
            {
                int l=read(),r=read(),val=read();
                int L=0,R=maxx+1;
                while(L!=R)
                {
                    int mid=L+R>>1;
                    int res=rank(1,1,n,l,r,mid);
                    //cout<<"***"<<res<<endl;
                    if(res<val)L=mid+1;
                    else R=mid;
                }
                printf("%d
    ",L-1);
            }
            else if(op==3)
            {
                int pos=read(),val=read();modify(1,1,n,pos,val);
                a[pos]=val;
                maxx=max(maxx,val);
            }
            else if(op==4)
            {
                int l=read(),r=read(),val=read();
                printf("%d
    ",getpre(1,1,n,l,r,val));
            }
            else if(op==5)
            {
                int l=read(),r=read(),val=read();
                printf("%d
    ",getnxt(1,1,n,l,r,val));
            }
        }
        return 0;
    }

    好了。

    从上面那段简短而狗屁不通的“题解”和几乎是抄来的代码可以看出来,是什么让当时的我那么垃圾。

    不求甚解、生搬硬套、懒于思考、依赖题解。

    装模作样打个Splay,考场上没板子真的写得出来?

    如果像本题一样,把普通平衡树的操作放到区间上,显然是无法只用平衡树维护的。解决区间问题最有力的武器就是线段树,所以考虑线段树套平衡树解决。

    对每个线段树区间建一棵平衡树。建树时直接把所有区间都插入该区间的所有元素,单点修改时把沿路的所有线段树区间上的平衡树都进行改动(删除再插入)。

    对于剩下的查询操作,求排名显然可以转化为所有区间小于该数的元素个数之和+1,即$( sum (每个区间求排名结果-1)) +1$,前驱应当是所有区间结果的最大值,同理后继就是最小值。

    但用相同的方式求K大是不太可行的,考虑牺牲一下时间复杂度进行二分答案,每次二分出一个数check它的排名即可。这样的话是3个$log$。

    平衡树使用的是替罪羊树,一是确实好写且容易封装,二是动态开点删点可以避免内存超限。这样就可以直接粗暴地扔到结构体里而不用像Splay一样使用$root[]$数组了。

    上面瞎写的东西我没有删。给自己和大家一个警示以及反面典型。

    #include<cstdio>
    #include<iostream>
    #include<cstring>
    #include<vector>
    using namespace std;
    int read()
    {
        int x=0,f=1;char ch=getchar();
        while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
        while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
        return x*f;
    }
    const int N=1e5+5,inf=2147483647;
    const double al=0.7;
    int n,m,a[N];
    struct Scapegoat
    {
        struct node
        {
            node *l,*r;
            int val,size,cnt;
            bool del;
            bool bad()
            {
                return l->cnt>al*cnt+5||r->cnt>al*cnt+5;
            }
            void up()
            {
                size=!del+l->size+r->size;
                cnt=1+l->cnt+r->cnt;
            }
        };
        node *null,**badtag;
        void dfs(node *k,vector<node*> &v)
        {
            if(k==null)return ;
            dfs(k->l,v);
            if(!k->del)v.push_back(k);
            dfs(k->r,v);
            if(k->del)delete k;
        }
        node *build(vector<node*> &v,int l,int r)
        {
            if(l>=r)return null;
            int mid=l+r>>1;
            node *k=v[mid];
            k->l=build(v,l,mid);
            k->r=build(v,mid+1,r);
            k->up();
            return k;
        }
        void rebuild(node* &k)
        {
            vector<node*> v;
            dfs(k,v);
            k=build(v,0,v.size());
        }
        void insert(int x,node* &k)
        {
            if(k==null)
            {
                k=new node;
                k->l=k->r=null;
                k->del=0;
                k->size=k->cnt=1;
                k->val=x;
                return ;
            }
            ++k->size;++k->cnt;
            if(x>=k->val)insert(x,k->r);
            else insert(x,k->l);
            if(k->bad())badtag=&k;
            else if(badtag!=&null)
                k->cnt-=(*badtag)->cnt-(*badtag)->size;
        }
        void ins(int x,node* &k)
        {
            badtag=&null;
            insert(x,k);
            if(badtag!=&null)rebuild(*badtag);
        }
        int getrk(node *now,int x)
        {
            int ans=1;
            while(now!=null)
            {
                if(now->val>=x)now=now->l;
                else
                {
                    ans+=now->l->size+!now->del;
                    now=now->r;
                }
            }
            return ans;
        }
        int kth(node *now,int x)
        {
            while(now!=null)
            {
                if(!now->del&&now->l->size+1==x)
                    return now->val;
                if(now->l->size>=x)now=now->l;
                else
                {
                    x-=now->l->size+!now->del;
                    now=now->r;
                }
            }
            return -1;
        }
        void erase(node *k,int rk)
        {
            if(!k->del&&rk==k->l->size+1)
            {
                k->del=1;
                --k->size;
                return ;
            }
            --k->size;
            if(rk<=k->l->size+!k->del)erase(k->l,rk);
            else erase(k->r,rk-k->l->size-!k->del);
        }
        node* root;
        Scapegoat()
        {
            null=new node;
            root=null;
        }
    }s[N<<3];
    #define ls(k) (k)<<1
    #define rs(k) (k)<<1|1
    void build(int k,int l,int r)
    {
        for(int i=l;i<=r;i++)
            s[k].ins(a[i],s[k].root);
        if(l==r)return ;
        int mid=l+r>>1;
        build(ls(k),l,mid);
        build(rs(k),mid+1,r);
    }
    int askrk(int k,int l,int r,int L,int R,int val)
    {
        if(L<=l&&R>=r)return s[k].getrk(s[k].root,val)-1;
        int mid=l+r>>1,res=0;
        if(L<=mid)res+=askrk(ls(k),l,mid,L,R,val);
        if(R>mid)res+=askrk(rs(k),mid+1,r,L,R,val);
        return res;
    }
    void update(int k,int l,int r,int pos,int val)
    {
        s[k].erase(s[k].root,s[k].getrk(s[k].root,a[pos]));
        s[k].ins(val,s[k].root);
        if(l==r)return ;
        int mid=l+r>>1;
        if(pos<=mid)update(ls(k),l,mid,pos,val);
        else update(rs(k),mid+1,r,pos,val);
    }
    int askpre(int k,int l,int r,int L,int R,int val)
    {
        if(L<=l&&R>=r)return s[k].kth(s[k].root,s[k].getrk(s[k].root,val)-1);
        int res=-inf,mid=l+r>>1;
        if(L<=mid)
        {
            int ret=askpre(ls(k),l,mid,L,R,val);
            if(ret==-1)res=max(res,-inf);
            else res=max(res,ret);
        }
        if(R>mid)
        {
            int ret=askpre(rs(k),mid+1,r,L,R,val);
            if(ret==-1)res=max(res,-inf);
            else res=max(res,ret);
        }
        return res;
    }
    int asknxt(int k,int l,int r,int L,int R,int val)
    {
        if(L<=l&&R>=r)return s[k].kth(s[k].root,s[k].getrk(s[k].root,val+1));
        int res=inf,mid=l+r>>1;
        if(L<=mid)
        {
            int ret=asknxt(ls(k),l,mid,L,R,val);
            if(ret==-1)res=min(res,inf);
            else res=min(res,ret);
        }
        if(R>mid)
        {
            int ret=asknxt(rs(k),mid+1,r,L,R,val);
            if(ret==-1)res=min(res,inf);
            else res=min(res,ret);
        }
        return res;
    }
    int askth(int L,int R,int val)
    {
        int l=0,r=1e8,res;
        while(l<=r)
        {
            int mid=l+r>>1;
            if(askrk(1,1,n,L,R,mid)+1<=val)res=mid,l=mid+1;
            else r=mid-1;
        }
        return res;
    }
    
    int main()
    {
        n=read();m=read();
        for(int i=1;i<=n;i++)
            a[i]=read();
        build(1,1,n);
        while(m--)
        {
            int op=read();
            if(op==1){int l=read(),r=read(),K=read();printf("%d
    ",askrk(1,1,n,l,r,K)+1);}
            if(op==2){int l=read(),r=read(),K=read();printf("%d
    ",askth(l,r,K));}
            if(op==3){int pos=read(),K=read();update(1,1,n,pos,K);a[pos]=K;}
            if(op==4){int l=read(),r=read(),K=read();printf("%d
    ",askpre(1,1,n,l,r,K));}
            if(op==5){int l=read(),r=read(),K=read();printf("%d
    ",asknxt(1,1,n,l,r,K));}
        }
        return 0;
    }
    
  • 相关阅读:
    python之list,tuple,str,dic简单记录(二)
    python 实现简单点名程序
    python random使用生成随机字符串
    第三方库PIL简单使用
    python Tkinter图形用户编程简单学习(一)
    python之list,tuple,str,dic简单记录(一)
    javascript 对象简单介绍(二)
    iOS使用NSMutableAttributedString 实现富文本(不同颜色字体、下划线等)
    iOS 10 因苹果健康导致闪退 crash
    iOS 10 创建iMessage App
  • 原文地址:https://www.cnblogs.com/Rorschach-XR/p/11019342.html
Copyright © 2011-2022 走看看