zoukankan      html  css  js  c++  java
  • 主席树——多棵线段树的集合

    主席树:

    (不要管名字)

    我们有的时候,会遇到很多种情况,对于每一种情况,都需要通过线段树的操作实现。

    碰巧的是,相邻两种情况下的线段树的差异不大。(总体的差异次数是O(N)级别的,均摊就是O(常数)的了)

    显然的是,我们不能对于每种情况都建造一棵线段树。n^n 空间直接MLE无疑。

    救命稻草是:发现相邻两种情况下的线段树的差异不大。

    所以,我们是否可以让不同的线段树共用同一个节点呢?!?!?

    这就是主席树的本质。也是精妙之处所在。

    代码实现不是很麻烦。

    我一般用传返回值形式,每次返回一个节点编号,便于设置儿子编号。比较方便。

    注意的是,我们必须记录lson,rson,不能采用x<<1,x<<1|1的形式。因为没有这样的规律可循。

    你不知道子节点和自己有什么关系。(这是谁家的孩子?公家的)

    经典例题:

    1.区间第k小(大)。

    离散化必须的。

    对于每一个区间节点开一个权值线段树 。i的线段树的节点l~r表示,在真正的区间1~i中,大小在l~r的数出现的次数。

    记录每个线段树节点根的所在位置。

    查询的时候,l-1,r两棵线段树同时出发,区间[a,b]sum值做一个差,就是l~r这个区间内,数值在[a,b]之间的数的个数。

    对于区间第k小,选择左儿子区间做差,u<k,就进入右儿子,同时k-=u

    否则进入左儿子。

    区间第k大正相反。

    对于n棵主席树,相邻两个主席树i,i-1只在i的数值位置的值不一样。

    所以,相邻的主席树只会增加logn个节点。

    总空间复杂度nlogn

    代码:

    #include<bits/stdc++.h>
    using namespace std;
    const int N=2e5+10;
    int n,m;
    int a[N],num[N];
    int cnt,rt[N];
    int li(int x){
        return lower_bound(num+1,num+cnt+1,x)-num;
    }
    struct node{
        int sum,lson,rson;
    }t[N*18];
    int tot;
    int add(int x,int l,int r,int c){
        tot++;
        int ret=tot;
        t[tot].sum=t[x].sum+1;
        if(l==r){
            return ret;
        }
        int mid=(l+r)>>1;
        if(c<=mid){
            t[tot].rson=t[x].rson;
            t[tot].lson=add(t[x].lson,l,mid,c);
        } 
        else{
            t[tot].lson=t[x].lson;
            t[tot].rson=add(t[x].rson,mid+1,r,c);
        }
        return ret;
    }
    int query(int x,int y,int l,int r,int k){
        if(l==r){
            return l;
        }
        int mid=(l+r)>>1;
        int u=t[t[y].lson].sum-t[t[x].lson].sum;
        if(u>=k){
            return query(t[x].lson,t[y].lson,l,mid,k);
        }
        else{
            return query(t[x].rson,t[y].rson,mid+1,r,k-u);
        }
    }
    int main()
    {
        scanf("%d%d",&n,&m);int df;
        for(int i=1;i<=n;i++) scanf("%d",&a[i]),num[++cnt]=a[i];
        sort(num+1,num+cnt+1);
        cnt=unique(num+1,num+cnt+1)-num-1;
        rt[0]=++tot;
        for(int i=1;i<=n;i++) rt[i]=add(rt[i-1],1,cnt,li(a[i]));
        int op,l,r;
        while(m--){
            scanf("%d%d%d",&l,&r,&op);
            int ot=query(rt[l-1],rt[r],1,cnt,op);
            printf("%d
    ",num[ot]);
        }
        ret

     

    2.spoj 10628 Count on a tree

    给定一棵N个节点的树,每个点有一个权值,对于M个询问(u,v,k),你需要回答u xor lastans和v这两个节点间第K小的点权。其中lastans是上一个询问的答案,初始为0,即第一个询问的u是明文。

    树上区间第k小,从根节点开始每个节点建主席树,从父亲版本跟新过来。

    对于x的主席树,[a,b]表示,从根到x的路径上,点权值在[a,b]之间的点数。

    儿子和父亲的主席树的差异,仅在x的点权数值的位置上。

    类似树上差分,

    对于询问的x,y,设它们的最近公共祖先是lca

    四棵树同时走,x,y路径上的点权点数的信息,就是sum[x]+sum[y]-sum[lca]-sum[fa[lca]](和点覆盖的树上差分类似)

    剩下的同理了

    代码:

    #include<bits/stdc++.h>
    using namespace std;
    const int N=100000+10;
    int n,m;
    int w[N];
    int rt[N];
    int num[N],mem;
    struct node{
        int nxt,to;
    }e[2*N];
    int hd[N],cnt;
    int li(int x){
        return lower_bound(num+1,num+mem+1,x)-num;
    }
    void add(int x,int y){
        e[++cnt].nxt=hd[x];e[cnt].to=y;hd[x]=cnt;
    }
    int dep[N];
    int fa[N][30];
    
    int lca(int x,int y){
        if(dep[x]<dep[y]) swap(x,y);
        for(int i=28;i>=0;i--){
            if(dep[fa[x][i]]>=dep[y])
             x=fa[x][i];
        }
        if(x==y) return x;
        for(int i=28;i>=0;i--){
            if(fa[x][i]!=fa[y][i])
             x=fa[x][i],y=fa[y][i];
        }
        return fa[x][0];
    }
    struct tr{
        int sum,lson,rson;
        #define ls(x) t[x].lson
        #define rs(x) t[x].rson
        #define s(x) t[x].sum
    }t[N*30];
    int tot;
    int upda(int x,int l,int r,int c){
        int ret=++tot;
        t[ret].sum=t[x].sum+1;
        if(l==r) return ret;
        int mid=l+r>>1;
        if(c<=mid){
            rs(ret)=rs(x);
            ls(ret)=upda(ls(x),l,mid,c);
        }
        else{
            ls(ret)=ls(x);
            rs(ret)=upda(rs(x),mid+1,r,c);
        }
        return ret;
    }
    int query(int x,int y,int z,int p,int l,int r,int k){
        if(l==r) return l;
        int mid=l+r>>1;
        int u=s(ls(x))+s(ls(y))-s(ls(p))-s(ls(z));
        if(k<=u)return query(ls(x),ls(y),ls(z),ls(p),l,mid,k);
        else return query(rs(x),rs(y),rs(z),rs(p),mid+1,r,k-u); 
    }
    void dfs(int x,int d){
        dep[x]=d;
        rt[x]=upda(rt[fa[x][0]],1,mem,li(w[x]));
        for(int i=hd[x];i;i=e[i].nxt){
            int y=e[i].to;
            if(y==fa[x][0]) continue;
            fa[y][0]=x;
            dfs(y,d+1);    
        }
    }
    int main()
    {
        scanf("%d%d",&n,&m);
        for(int i=1;i<=n;i++) scanf("%d",&w[i]),num[++mem]=w[i];
        sort(num+1,num+mem+1);
        mem=unique(num+1,num+mem+1)-num-1;
        
        int x,y;
        for(int i=1;i<=n-1;i++){
            scanf("%d%d",&x,&y);
            add(x,y);add(y,x);
        }
        dfs(1,1);
        dep[0]=-1;
        for(int i=1;(1<<i)<=n;i++)
         for(int j=1;j<=n;j++){
             fa[j][i]=fa[fa[j][i-1]][i-1];
        }
        int op;
        int las=0;
        while(m--){
            scanf("%d%d%d",&x,&y,&op);
            x^=las;
            int anc=lca(x,y);
            //cout<<" anc "<<anc<<endl;
            int ot=query(rt[x],rt[y],rt[anc],rt[fa[anc][0]],1,mem,op);
            las=num[ot];
            printf("%d
    ",num[ot]);
        }
        return 0;
    }
    Count on a tree

    3.[国家集训队]middle

    一个长度为n的序列a,设其排过序之后为b,其中位数定义为b[n/2],其中a,b从0开始标号,除法取下整。

    给你一个长度为n的序列s。

    回答Q个这样的询问:s的左端点在[a,b]之间,右端点在[c,d]之间的子序列中,最大的中位数。

    其中a<b<c<d。

    位置也从0开始标号。

    我会使用一些方式强制你在线。

    这个题就比较的巧妙了。不像之前的套路性的第k问题。

    这个是真真正正地用主席树替代了线段树。

    首先,对于区间中位数一个比较套路的做法是:

    二分一个答案mid,把所有>=mid的数值设成1,<mid的值设为-1

    查询区间内的和是否>=0(这个题是>=0,题意中,偶数项的中位数是中间的那两个靠后的那一个)

    是,中位数应该更大,

    否则,中位数只能更小。

    先不考虑复杂度。

    给[a,d]区间的数赋值为1、-1

    这个题,区间都不是固定的。

    但是,[b+1,c-1]的值是必选的。计算一下这个区间的和。

    对于[a,b],[c,d]

    因为要让中位数尽可能的大。

    所以,争取选择尽可能多的1

    找一个[a,b]的最大后缀,[c,d]的最大前缀。

    这三个和就是对于mid的最大的和了,可以进行判断。

    因为多组询问,而数组不会改变,

    而离散化之后,中位数的值在1~n之间。

    所以,对于每一个二分的mid值,建一棵线段树。

    线段树以区间下标为下标,记录区间和,区间最大后缀,最大前缀。

    就可以O(logn)判断mid是否可以更优了。

    空间又炸了。所以主席树闪亮登场!!!

    发现,对于mid变成mid+1,只有值为mid的数的值会从+1变成-1.

    主席树在前者的基础上暴力修改。

    每一个数就会改一次,所以均摊logn空间。、

    时间复杂度:nlogn^2

    空间复杂度:nlogn

    代码:(vector 记录数字出现的位置)

    #include<bits/stdc++.h>
    using namespace std;
    const int N=20000+10;
    int n,m;
    int a[N],num[N],mem;
    int rt[N];
    int x1,x2,x3,x4;
    int li(int x){
        return lower_bound(num+1,num+mem+1,x)-num;
    }
    struct node{
        int sum,lmx,rmx;
        int lson,rson;
        bool ncl,ncr;
        #define s(x) t[x].sum
        #define ls(x) t[x].lson
        #define rs(x) t[x].rson
        #define lm(x) t[x].lmx
        #define rm(x) t[x].rmx
        #define cl(x) t[x].ncl
        #define cr(x) t[x].ncr
    }t[N*40];
    int tot;
    vector<int>pos[N];
    void pushup(int x){
        s(x)=s(ls(x))+s(rs(x));
        lm(x)=max(lm(ls(x)),s(ls(x))+lm(rs(x)));
        rm(x)=max(rm(rs(x)),s(rs(x))+rm(ls(x)));
    }
    int build(int l,int r){
        int id=++tot;
        if(l==r){
            if(li(a[l])<=1) s(id)=lm(id)=rm(id)=-1;
            else s(id)=lm(id)=rm(id)=1;
            return id;
        }
        int mid=l+r>>1;cl(id)=1;cr(id)=1;
        ls(id)=build(l,mid);rs(id)=build(mid+1,r);
        pushup(id);
        return id;
    }
    int upda(int x,int y,int l,int r,int to,int c,bool nc){
        if(!nc) {
        x=++tot;
        }
        if(l==r){
            s(x)=lm(x)=rm(x)=c;
            return x;
        }
        
        int mid=(l+r)>>1;
        if(to<=mid){
            if(!cr(x)) rs(x)=rs(y);
            if(!cl(x)){
                cl(x)=1;ls(x)=upda(x,ls(y),l,mid,to,c,0);
            }
            else{
                ls(x)=upda(ls(x),ls(y),l,mid,to,c,1);
            }
        }
        else{
            if(!cl(x)) ls(x)=ls(y);
            if(!cr(x)){
                cr(x)=1;rs(x)=upda(x,rs(y),mid+1,r,to,c,0);
            }
            else{
                rs(x)=upda(rs(x),rs(y),mid+1,r,to,c,1);
            }
        }
        pushup(x);
        return x;
    }
    int qs(int x,int l,int r,int L,int R){
        if(L<=l&&r<=R){
            return s(x);
        }
        int mid=l+r>>1;int ret=0;
        if(L<=mid) ret+=qs(ls(x),l,mid,L,R);
        if(mid<R) ret+=qs(rs(x),mid+1,r,L,R);
        return ret;
    }
    node ql(int x,int l,int r,int L,int R){
        
        if(L<=l&&r<=R){
            return t[x];
        }
        int mid=l+r>>1;
        if(L<=mid&&mid<R){
            node ret;
            node le=ql(ls(x),l,mid,L,R);
            node ri=ql(rs(x),mid+1,r,L,R);
            ret.sum=le.sum+ri.sum;
            ret.lmx=max(le.lmx,le.sum+ri.lmx);
            return ret;
        }
        else if(L<=mid){
            return ql(ls(x),l,mid,L,R);
        }
        else {
            return ql(rs(x),mid+1,r,L,R);
        }
    }
    node qr(int x,int l,int r,int L,int R){
        if(L<=l&&r<=R){
            return t[x];
        }
        int mid=l+r>>1;
        if(L<=mid&&mid<R){
            node ret;
            node le=qr(ls(x),l,mid,L,R);
            node ri=qr(rs(x),mid+1,r,L,R);
            ret.sum=le.sum+ri.sum;
            ret.rmx=max(ri.rmx,ri.sum+le.rmx);
            return ret;
        }
        else if(L<=mid){
            return qr(ls(x),l,mid,L,R);
        }
        else {
            return qr(rs(x),mid+1,r,L,R);
        }
    }
    bool che(int val){
        int sz=0;
        if(x2+1<=x3-1) sz=qs(rt[val],1,n,x2+1,x3-1);
        int sr=ql(rt[val],1,n,x3,x4).lmx;
        int sl=qr(rt[val],1,n,x1,x2).rmx;
        return (sl+sz+sr)>=0;
    }
    int main()
    {
        scanf("%d",&n);
        for(int i=1;i<=n;i++) scanf("%d",&a[i]),num[++mem]=a[i];
        sort(num+1,num+mem+1);
        mem=unique(num+1,num+mem+1)-num-1;
        for(int i=1;i<=n;i++){
            pos[li(a[i])].push_back(i);
        }
        rt[1]=build(1,n);
        for(int i=2;i<=mem;i++){
           for(int j=0;j<pos[i-1].size();j++){
               int go=pos[i-1][j];
                rt[i]=upda(rt[i],rt[i-1],1,n,go,-1,rt[i]>0);
           }
        }
        scanf("%d",&m);
        int las=0;
        
        int ch[6];
        while(m--){
            scanf("%d%d%d%d",&x1,&x2,&x3,&x4);
            ch[1]=(x1+las)%n;
            ch[2]=(x2+las)%n;
            ch[3]=(x3+las)%n;
            ch[4]=(x4+las)%n;
            sort(ch+1,ch+4+1);
            x1=ch[1]+1,x2=ch[2]+1,x3=ch[3]+1,x4=ch[4]+1;
            int l=1,r=mem;
            int ans=0;
            while(l<=r){
                int mid=l+r>>1;
                if(che(mid)) ans=mid,l=mid+1;
                else r=mid-1;
            }
            las=num[ans];
            printf("%d
    ",las);
        }
        return 0;
    }
  • 相关阅读:
    oracle的commit
    struts2 Action 接收参数的三种方法
    git -速查表
    Windows 手动创建 服务
    Linux 上 安装 composer
    Class文件解析
    Java 从数据库中查找信息导入Excel表格中
    将Java Web项目部署到远程主机上
    Java8 map和reduce
    group By 和 Union 、 Union all的用法
  • 原文地址:https://www.cnblogs.com/Miracevin/p/9368361.html
Copyright © 2011-2022 走看看