zoukankan      html  css  js  c++  java
  • 主席树,喵~

    稍微总结一下主席树吧

    Too Difficult!搞了一天搞出一大堆怎么令人悲伤的辣鸡代码。总之先总结一下吧,以后碰到这种问题直接拿去毒害队友好了。

    UPD 5/24 苟狗是沙比


    一个节点记录三个信息:lson,rson,sum

    pid表示节点个数。

    build

    void build(int &k,int l,int r){
        k=++pid;
        if(l==r) return;
        int mid=(l+r)>>1;
        build(lson[k],l,mid);
        build(rson[k],mid+1,r);
    }
    

    change

    void change(int old,int &k,int l,int r,int pos,int x){
        k=++pid;
        lson[k]=lson[old],rson[k]=rson[old],sum[k]=sum[old]+x;
        if(l==r) return;
        int mid=(l+r)>>1;
        if(pos<=mid) change(lson[old],lson[k],l,mid,pos,x);
        else change(rson[old],rson[k],mid+1,r,pos,x);
    }
    

    Lv.1 最基本的操作

    • 区间k大值
    int query(int old_k,int new_k,int l,int r,int x){
        if(l==r) return sum[new_k]-sum[old_k]>0?l:-1;
        int mid=(l+r)>>1;
        int cntLeft = sum[lson[new_k]]-sum[lson[old_k]];
        if (cntLeft<x) {
            return query(rson[old_k],rson[new_k],mid+1,r,x-cntLeft);
        } else {
            return query(lson[old_k],lson[new_k],l,mid,x);
        }
    }
    
    
    • 区间内有多少个数字小于等于x
    int query(int new_k,int old_k,int l,int r,int x) { // cnt <= x
        if(x<l) return 0; // 这个地方比较喜。小心点。
        if(l==r) return sum[new_k]-sum[old_k];
        int mid=(l+r)>>1;
        if(mid<x) return sum[lson[new_k]]-sum[lson[old_k]]+query(rson[new_k],rson[old_k],mid+1,r,x);
        else return query(lson[new_k],lson[old_k],l,mid,x);
    }
    
    • 查询区间<=x的最大数字:上两条的组合技。

    两道入门题:POJ2104HDU4417

    主席树相当于对每一个前缀都维护一个线段树,然后发现相邻两棵线段树长得好像哎!所以我们可以动态开点啦!

    解决问题的时候,我们通常会对每一个前缀,维护一个权值线段树。每个值域存要维护的信息。


    既然是维护每一个前缀,所以,我们不仅能拿主席树来施展线性结构,还能施展树状结构!比如说我们可以查询树上两点间路径点权的k小值。

    Lv.2 树上路径上点权k小值

    栗子:SPOJ-COT

    线性结构上

    iterval(l,r)=T(r)-T(l-1)

    树状结构上

    path(u,v) = T(u)+T(v)-T(lca)-T(Parent of lca)


    Lv.2 矩形内有多少个点

    给出很多个点。Q组询问,每组询问查询一个矩形内有几个点。

    按横坐标排序,把纵坐标放到主席树上,然后就相当于区间内有多少个数字小于等于x啦!

    栗子:CF853C

    把细节考虑好!还是很友好的。

    #include <iostream>
    #include <algorithm>
    using namespace std;
    const int N=6000000+10;
    #define f(x) (1LL*x*(x-1)/2)
    typedef long long LL;
    int lson[N],rson[N],sum[N],root[N],pid;
    int n,q,p[N];
    void build(int &k,int l,int r){
        k=++pid;
        if(l==r) return;
        int mid=(l+r)>>1;
        build(lson[k],l,mid);
        build(rson[k],mid+1,r);
    }
    void change(int old,int &k,int l,int r,int pos,int x) {
        k=++pid;i
        lson[k]=lson[old],rson[k]=rson[old],sum[k]=sum[old]+x;
        if(l==r) return;
        int mid=(l+r)>>1;
        if(pos<=mid) change(lson[k],lson[k],l,mid,pos,x);
        else change(rson[k],rson[k],mid+1,r,pos,x);
    }
    int query(int new_k,int old_k,int l,int r,int x) { // cnt <= x
        if(x<l) return 0;
        if(l==r) return sum[new_k]-sum[old_k];
        int mid=(l+r)>>1;
        if(mid<x) return sum[lson[new_k]]-sum[lson[old_k]]+query(rson[new_k],rson[old_k],mid+1,r,x);
        else return query(lson[new_k],lson[old_k],l,mid,x);
    }
    int count(int x1,int x2,int y1,int y2) { // 
        if(x1>x2||y1>y2) return 0;
        int cnt1 = query(root[x2],root[x1-1],1,n,y1-1);
        int cnt2 = query(root[x2],root[x1-1],1,n,y2);
        return cnt2-cnt1;
    } 
    int main(){
        scanf("%d%d",&n,&q);
        for(int i=1;i<=n;i++) {
            scanf("%d",&p[i]);
        }
        build(root[0],1,n);
        for(int i=1;i<=n;i++) {
            change(root[i-1],root[i],1,n,p[i],1);
        }
        for(int i=1;i<=q;i++){
            int l,d,r,u;
            scanf("%d%d%d%d",&l,&d,&r,&u);
            int LU = count(1,l-1,u+1,n);
            int LD = count(1,l-1,1,d-1);
            int RU = count(r+1,n,u+1,n);
            int RD = count(r+1,n,1,d-1);
    
            int L = l-1; int U = n-u; 
            int R = n-r; int D = d-1;
            
            LL A = f(L)+f(R)+f(U)+f(D);
            LL B = f(LU)+f(LD)+f(RU)+f(RD);
            LL ret = 1LL*n*(n-1)/2-(A-B);
            printf("%lld
    ", ret);
        }
    }
    
    

    Lv.2 区间内出现数字的个数

    权值线段树直接投降了,不过我们可以在某个元素上一次出现的位置insert -1,在当前出现的位置insert 1

    种树之前想清楚该维护什么啊!

    栗子: HDU5919

    题解:因为是统计区间内,每个数字第一次出现的位置。

    所以我们可以倒着做。从后往前遍历,遇到一个数字,在这个数字上一次出现的位置加上-1,当前位置加上1.

    在从后往前遍历的同时,我们对于每一个后缀建一棵线段树。维护后缀中,每个元素第一次出现的位置。

    对于每组询问,先求出区间内有多少种不同的数字,然后查询第(cnt+1)/2大即可。

    #include <iostream>
    #include <map>
    using namespace std;
    const int N = 10000000+10;
    int lson[N],rson[N],root[N],sum[N],pid;
    int T,cas;
    
    void build(int &k,int l,int r) {
        k=++pid;
        if(l==r) return;
        int mid=(l+r)>>1;
        build(lson[k],l,mid);
        build(rson[k],mid+1,r);
    }
    void update(int old,int &k,int l,int r,int pos,int x) {
        k=++pid; sum[k] = 0;
        lson[k]=lson[old], rson[k]=rson[old], sum[k]=sum[old]+x;
        if(l==r) return;
        int mid=(l+r)>>1;
        if (pos<=mid) 
            update(lson[old],lson[k],l,mid,pos,x);
        else
            update(rson[old],rson[k],mid+1,r,pos,x);
    }
    int query_x_th(int k,int l,int r,int x) {
        if (l == r) 
            return l;
        int mid = (l+r)>>1;
        if (sum[lson[k]] < x) {
            return query_x_th(rson[k],mid+1,r,x-sum[lson[k]]);
        } else {
            return query_x_th(lson[k],l,mid,x);
        }
    }
    int count(int k,int l,int r,int L,int R) {
        if(L<=l&&r<=R) {
            return sum[k];
        }
        int mid = (l+r)>>1;
        int ans = 0;
        if (L<=mid) ans += count(lson[k],l,mid,L,R);
        if (R >mid) ans += count(rson[k],mid+1,r,L,R);
        return ans;
    }
    
    int n, m, a[N];
    map<int,int> las;
    void init() {
        las.clear();
        pid = 0;
    }
    int main() {
        scanf("%d",&T);
        while (T --) {
            init();
            scanf("%d %d",&n,&m);
            for(int i=1;i<=n;i++) {
                scanf("%d", &a[i]); 
            }
    
            build(root[n+1],1,n);
            for(int i=n;i>=1;i--) {
                update(root[i+1],root[i],1,n,i,1);
                
                if ( las.find(a[i]) != las.end() )
                    update(root[i],root[i],1,n,las[a[i]], -1);
                
                las[a[i]] = i;
            }
    
            printf("Case #%d:", ++cas);
            int ans=0;
            for(int i=1;i<=m;i++) {
                int l, r;
                scanf("%d %d", &l, &r);
                int nl = min((l+ans)%n+1, (r+ans)%n+1);
                int nr = max((l+ans)%n+1, (r+ans)%n+1);
                int tot = count(root[nl],1,n,nl,nr);
                ans = query_x_th(root[nl],1,n,(tot+1)/2);
                printf(" %d", ans);
            }
            printf("
    ");
    
        }
    }
    
    
    

    Lv.3 主席树的区间更新

    一种不用下传懒惰标记的姿势:对于区间查询,从上往下走的时候,对懒惰标记进行累加。

    栗子:HDU4348

    #include <iostream>
    #include <algorithm>
    #include <vector>
    using namespace std;
    typedef long long LL;
    const int N=6000000+10;
    int lson[N],rson[N],root[N],pid;
    LL sum[N],lazy[N];
    int n,q,a[N];
    
    void build(int &k,int l,int r){
        k=++pid; lazy[k] = 0; sum[k] = 0;
        if(l==r) {
            sum[k] = a[l];
            lson[k] = rson[k] = 0;
            return;
        }
        int mid=(l+r)>>1;
        build(lson[k],l,mid);
        build(rson[k],mid+1,r);
        sum[k] = sum[lson[k]] + sum[rson[k]];
    }
    void update(int old,int &k,int l,int r,int L,int R,int x){
        k=++pid; lazy[k] = 0; sum[k] = 0;
        lazy[k]=lazy[old]; sum[k] = sum[old];
        lson[k]=lson[old]; rson[k]=rson[old];
    
        if(L<=l&&r<=R) {
            lazy[k] = lazy[old] + x;
            sum[k] = sum[old] + 1LL*(r-l+1)*x;
            return;
        }
        int mid=(l+r)>>1;
        if (L<=mid)
            update(lson[k],lson[k],l,mid,L,R,x);
        if (R >mid)
            update(rson[k],rson[k],mid+1,r,L,R,x);
    
        sum[k] = sum[lson[k]] + sum[rson[k]] + 1LL*lazy[k]*(r-l+1);
    }
    LL query(int k,int l,int r,int add,int L,int R) {
        if (L<=l&&r<=R)
            return sum[k] + 1LL*(r-l+1)*add;
    
        add += lazy[k];
        int mid=(l+r)>>1;
        LL ans=0;
        if (L<=mid) ans += query(lson[k],l,mid,add,L,R);
        if (R >mid) ans += query(rson[k],mid+1,r,add,L,R);
        return ans;
    }
    int stamp = 0;
    void init() {
        stamp=0;
        pid=0;
    }
    int main(){
        while (~ scanf("%d%d",&n,&q)) {
            init();
            for(int i=1;i<=n;i++) scanf("%d",&a[i]);
            build(root[0],1,n);
            int id = 0;
            for(int i=1;i<=q;i++){
                char op[2]; int l,r,t;
                scanf("%s",op);
                if(op[0] == 'C') {
                    scanf("%d%d%d",&l,&r,&t);
                    update(root[stamp],root[stamp+1],1,n,l,r,t);
                    stamp ++;
                }
                if(op[0] == 'Q') {
                    scanf("%d%d",&l,&r);
                    LL ans = query(root[stamp],1,n,0,l,r); 
                    printf("%lld
    ", ans);
                }
                if(op[0] == 'H') { 
                    scanf("%d%d%d",&l,&r,&t);
                    LL ans = query(root[t],1,n,0,l,r);
                    printf("%lld
    ", ans);
                }
                if(op[0] == 'B'){
                    scanf("%d",&t);
                    stamp = t;
                }
            }
        }
    }
    

    一些练习

    CF650D

    题意:动态LIS,每次修改一个位置,每次操作查询LIS,操作相互独立

    题解:

    两种情况

    第一种,更新后pos,出现在了LIS中

    我们要做的是:查询[1,pos)中,h<h[pos]的所有数字,LISmax

    可以对每一个前缀维护一个h的权值线段树,每个节点记录h在此值域内LISmax

    第二种,更新后pos,没出现在LIS中

    判断一下pos是否在存在于所有的,原序列LIS中。

    这个地方很有趣。

    hint: dp[i]+rev_dp[i]=LIS+1

    Bonus:
    1. 存在一个LIS包含元素i的条件
    2. 所有LIS包含元素i的条件
    
    #include <iostream>
    #include <cmath>
    #include <cstring>
    #include <algorithm>
    #include <vector>
    using namespace std;
    
    const int N = 400000+10;
    const int INF = 1000000007;
    
    int bit[N];
    vector<int> v;
    int id(int x) {
        return lower_bound(v.begin(),v.end(),x)-v.begin()+1;
    }
    int get(int x) {
        int ans=0;
        while(x) {
            ans=max(ans,bit[x]);
            x-=x&-x;
        }
        return ans;
    }
    void upd(int pos,int x){
        while(pos<N) {
            bit[pos]=max(bit[pos],x);
            pos += pos&-pos;
        }
    }
    int n,m,h[N],dp[N],rdp[N],neccesary[N];
    int LIS=0;
    vector<int> pos[N];
    void compress(int on) {
        v.clear();
        if (on == 0) {
            for(int i=1;i<=n;i++) v.push_back(h[i]);
        } else {
            for(int i=1;i<=n;i++) v.push_back(INF-h[i]);
        }
        sort(v.begin(), v.end());
        v.erase(unique(v.begin(),v.end()),v.end());
    }
    void LIS_Proccess() {
        scanf("%d%d",&n,&m);
        for(int i=1;i<=n;i++) {
            scanf("%d",&h[i]);
        }
        compress(0);
    
        for(int i=1;i<=n;i++) {
            dp[i] = get(id(h[i])-1) + 1;
            upd(id(h[i]), dp[i]);
            LIS = max(LIS, dp[i]);
        }
    
        memset(bit,0,sizeof(bit));
        compress(1);
        for(int i=n;i>=1;i--) {
            rdp[i] = get(id(INF-h[i])-1) + 1;
            upd(id(INF-h[i]), rdp[i]);
        }
    
        for(int i=1;i<=n;i++) {
            if (dp[i]+rdp[i] == LIS+1) {
                pos[dp[i]].push_back(i);
            }
        }
    
        for(int i=1;i<=n;i++) {
            if (pos[i].size() == 1) {
                neccesary[pos[i][0]] = 1;
            }
        }
    
    }
    
    int lson[N*22],rson[N*22],val[N*22],root[N*22],pid;
    int ans[N], pre[N], suf[N], p[N], x[N];
    void build(int &k,int l,int r) {
        k=++pid; val[k]=0;
        if(l==r) return;
        int mid=(l+r)>>1;
        build(lson[k],l,mid);
        build(rson[k],mid+1,r);
    }
    void change(int old,int &k,int l,int r,int pos,int x) {
        k=++pid;
        lson[k]=lson[old],rson[k]=rson[old],val[k]=max(x,val[old]);
    
        if(l==r) return;
        int mid=(l+r)>>1;
        if(pos<=mid) change(lson[old],lson[k],l,mid,pos,x);
        else change(rson[old],rson[k],mid+1,r,pos,x);
    }
    int query(int k,int l,int r,int L,int R) {
        if(L>R) return 0;
        if(L<=l&&r<=R) {
    
            return val[k];
        }
        int mid=(l+r)>>1;
        int ans=0;
        if (L<=mid) ans=max(ans, query(lson[k],l,mid,L,R));
        if (R >mid) ans=max(ans, query(rson[k],mid+1,r,L,R));
        return ans;
    }
    
    int main() {
    
        LIS_Proccess();
    
        // neccesary[i]: 第i位一定出现在LIS中
        pid=0; compress(0);
        build(root[0],1,v.size());
        for(int i=1;i<=n;i++) {
            change(root[i-1],root[i],1,v.size(),id(h[i]),dp[i]);
        }
    
        for(int i=1;i<=m;i++) {
            scanf("%d%d",&p[i],&x[i]);
            ans[i] = neccesary[p[i]] ? LIS - 1 : LIS;
            pre[i] = query(root[p[i]-1], 1, v.size(), 1, id(x[i])-1);
        }
        //exit(0);
    
        pid=0; compress(1);
        build(root[n+1],1,v.size());
        for(int i=n;i>=1;i--) {
            change(root[i+1],root[i],1,v.size(),id(INF-h[i]),rdp[i]);
        }
        for(int i=1;i<=m;i++) {
            suf[i] = query(root[p[i]+1], 1, v.size(), 1, id(INF-x[i])-1);
            ans[i] = max(ans[i], pre[i]+suf[i]+1);
            printf("%d
    ", ans[i]);
        }
    
    }
    
    

    以上,于4/28,mark一下。

    之后,待补的坑:

    • BIT套主席树 【学不会】
    • 主席树的区间更新【已补】

    学数据结构是不可能学数据结构的,这辈子都不可能学数据结构!

  • 相关阅读:
    awk 正则匹配指定字段次数统计
    base64图片内容下载转为图片保存
    基于keras的fasttext短文本分类
    ubuntu 更换为mac主题
    ubuntu crontab python 定时任务备记
    ubuntu14.04 安装jdk1.8及以上
    fastext 中文文本分类
    django 多线程下载图片
    中文词向量训练
    mongodb 安装使用备记
  • 原文地址:https://www.cnblogs.com/RUSH-D-CAT/p/8965601.html
Copyright © 2011-2022 走看看