zoukankan      html  css  js  c++  java
  • 主席树总结

    之前一直以为主席树是个什么神仙玩意儿,后面看了下其实也不是很难。主席树也被称作可持久化线段树吧,这里的线段树一般是权值线段树,普通的权值线段树只能维护整个区间的权值信息,对于部分的区
    间信息不能维护。而对于每个主席树的(root[i]),它存的是(1)~(i)的区间权值信息,相当于一个权值线段树的前缀和。因为某些问题区间信息具有可减性,所以可以用主席树来维护。这个之后有例题。
    但对于每个前缀(i)都建立一颗权值线段树空间复杂度太高了,所以这里比较巧妙的一个地方就是前缀(i)与前缀(i-1)会有很多的重合部分,所以我们可以共用很多结点。那么每次插入前缀(i)时,共用的部分不>管,对于其余的结点我们新开一个结点就是了,所以每次最多就会开一条链。那么最后总的空间复杂度就很低了,为(O(nlogn+nlogn))。时间复杂度也很低,为(O(nlogn))

    部分代码

    一开始的时候,我们会建立一颗空树,之后会以此为基础来插入前缀。

    void build(int &o, int l, int r) {
        o = ++T;
        if(l == r) {
            return ;
        }
        int mid = (l + r) >> 1;
        build(ls[o], l, mid) ;
        build(rs[o], mid + 1, r) ;
    }
    

     
    之后就是对于每个前缀(i)的插入了。

    void update(int &o, int l, int r, int last, int p) {
        o = ++T;
        ls[o] = ls[last] ;
        rs[o] = rs[last] ;//共用结点
        if(l == r) {
            sum[o] = sum[last] + 1;
            return ;
        }
        int mid = (l + r) >> 1;
        if(p <= mid) update(ls[o], l, mid, ls[last], p) ;
        else update(rs[o], mid + 1, r, rs[last], p);
        pushup(o) ;
    }
    //main中
    for(int i = 1; i <= n; i++) update(rt[i], 1, n, rt[i - 1], a[i]) ;
    

     
    对于询问(以区间第k大为例):

    int query(int L, int R, int l, int r, int k) {
        if(l == r) return l;
        int s = sum[ls[R]] - sum[ls[L]] ;
        int mid = (l + r) >> 1 ;
        if(s >= k) return query(ls[L], ls[R], l, mid, k) ;
        else return query(rs[L], rs[R], mid + 1, r, k - s) ;
    }
    

    例题

    [POJ2104区间第k大](http://poj.org/problem?id=2104)
    这是主席树基本用途之一,查询的时候通过比较区间权值数量来决定进入左子树还是右子树,具体见代码吧。 代码如下:
    Code
    #include <cstdio>
    #include <cstring>
    #include <algorithm>
    #include <iostream>
    #include <cmath>
    using namespace std;
    typedef long long ll;
    const int N = 1e5 + 5, M = 5005;
    int n, m, T;
    int a[N], b[N], c[N], d[N];
    int rt[N] ;
    int sum[N * 20], ls[N * 20], rs[N * 20];
    void pushup(int o) {
        sum[o] = sum[ls[o]] + sum[rs[o]] ;
    }
    void build(int &o, int l, int r) {
        o = ++T;
        if(l == r) {
            return ;
        }
        int mid = (l + r) >> 1;
        build(ls[o], l, mid) ;
        build(rs[o], mid + 1, r) ;
    }
    void update(int &o, int l, int r, int last, int p) {
        o = ++T;
        ls[o] = ls[last] ;
        rs[o] = rs[last] ;
        if(l == r) {
            sum[o] = sum[last] + 1;
            return ;
        }
        int mid = (l + r) >> 1;
        if(p <= mid) update(ls[o], l, mid, ls[last], p) ;
        else update(rs[o], mid + 1, r, rs[last], p);
        pushup(o) ;
    }
    int query(int L, int R, int l, int r, int k) {
        if(l == r) return l;
        int s = sum[ls[R]] - sum[ls[L]] ;
        int mid = (l + r) >> 1 ;
        if(s >= k) return query(ls[L], ls[R], l, mid, k) ;
        else return query(rs[L], rs[R], mid + 1, r, k - s) ;
    }
    int main() {
        ios::sync_with_stdio(false); cin.tie(0);
        cin >> n >> m;
        for(int i = 1; i <= n; i++) cin >> a[i], b[i] = c[i] = a[i];
        sort(b + 1, b + n + 1) ;
        int D = unique(b + 1, b + n + 1) - b - 1;
        for(int i = 1; i <= n; i++) a[i] = lower_bound(b + 1, b + D + 1, a[i]) - b, d[a[i]] = c[i];
        build(rt[0], 1, D) ;
        for(int i = 1; i <= n; i++) update(rt[i], 1, D, rt[i - 1], a[i]) ;
        for(int i = 1; i <= m; i++) {
            int l, r, k ;
            cin >> l >> r >> k ;
            int ans = query(rt[l - 1], rt[r], 1, D, k) ;
            cout << d[ans] << '
    ';
        }
        return 0;
    }
    
    [洛谷模板题](https://www.luogu.org/problemnew/show/P3919)
    可以用这个加深一下对可持久化的理解吧。 代码如下:
    Code
    #include <bits/stdc++.h>
    using namespace std;
    typedef long long ll;
    const int N = 1e6 + 5 ;
    int n, m;
    int a[N];
    int rt[N], tre[N * 20], lc[N * 20], rc[N * 20];
    int T;
    void build(int &o, int l, int r) {
        o = ++T;
        if(l == r) {
            tre[o] = a[l] ;
            return ;
        }
        int mid = (l + r) >> 1;
        build(lc[o], l, mid) ;
        build(rc[o], mid + 1, r) ;
    }
    void update(int &o, int l, int r, int last, int sign, int p, int v) {
        o = ++T;
        lc[o] = lc[last];
        rc[o] = rc[last];
        if(sign == 0)
            return ;
        if(l == r) {
            tre[o] =  v;
            return ;
        }
        int mid = (l + r) >> 1;
        if(p <= mid) update(lc[o], l, mid, lc[last], sign, p, v) ;
        else update(rc[o], mid + 1, r, rc[last], sign, p, v) ;
    }
    int query(int o, int l, int r, int p) {
        if(l == r) return tre[o] ;
        int mid = (l + r) >> 1;
        if(p <= mid) return query(lc[o], l, mid, p) ;
        else return query(rc[o], mid + 1, r, p) ;
    }
    int main() {
        ios::sync_with_stdio(false); cin.tie(0);
        cin >> n >> m ;
        for(int i = 1; i <= n; i++) cin >> a[i] ;
        build(rt[0], 1, n) ;
        for(int i = 1; i <= m; i++) {
            int v, op, pos, val;
            cin >> v >> op >> pos ;
            if(op == 1) {
                cin >> val ;
                update(rt[i], 1, n, rt[v], 1, pos, val) ;
            } else {
                int ans = query(rt[v], 1, n, pos) ;
                update(rt[i], 1, n, rt[v], 0, pos, 0);
                cout << ans << '
    ';
            }
        }
        return 0;
    }
    
    [洛谷P3567](https://www.luogu.org/problemnew/show/P3567)
    题意是询问区间中是否有超过区间长度一半的数,有的话这个数是哪个。 维护一下个数,然后根据sum看是否进入左右子树来找就行了,如果最后没有则输出0。 代码如下:
    Code
    #include <cstdio>
    #include <cstring>
    #include <algorithm>
    #include <iostream>
    #include <cmath>
    #include <map>
    using namespace std;
    typedef long long ll;
    const int N = 5e5 + 5;
    int n, m, T;
    int a[N];
    int rt[N] ;
    int sum[N * 40], ls[N * 40], rs[N * 40];
    void pushup(int o) {
        sum[o] = sum[ls[o]] + sum[rs[o]] ;
    }
    void build(int &o, int l, int r) {
        o = ++T;
        if(l == r) {
            return ;
        }
        int mid = (l + r) >> 1;
        build(ls[o], l, mid) ;
        build(rs[o], mid + 1, r) ;
    }
    void update(int &o, int l, int r, int last, int p) {
        o = ++T;
        ls[o] = ls[last] ;
        rs[o] = rs[last] ;
        if(l == r) {
            sum[o] = sum[last] + 1;
            return ;
        }
        int mid = (l + r) >> 1;
        if(p <= mid) update(ls[o], l, mid, ls[last], p) ;
        else update(rs[o], mid + 1, r, rs[last], p);
        pushup(o) ;
    }
    int query(int L, int R, int l, int r, int k) {
        if(l == r) return l;
        int lsize = sum[ls[R]] - sum[ls[L]], rsize = sum[rs[R]] - sum[rs[L]];
        int mid = (l + r) >> 1 ;
        if(lsize >= k) return query(ls[L], ls[R], l, mid, k) ;
        else if(rsize >= k) return query(rs[L], rs[R], mid + 1, r, k) ;
        else return 0;
    }
    int main() {
        ios::sync_with_stdio(false); cin.tie(0);
        cin >> n >> m;
        for(int i = 1; i <= n; i++) cin >> a[i];
        build(rt[0], 1, n) ;
        for(int i = 1; i <= n; i++) update(rt[i], 1, n, rt[i - 1], a[i]) ;
        for(int i = 1; i <= m; i++) {
            int l, r;
            cin >> l >> r;
            int mid = ((r - l + 1) >> 1) + 1;
            int v = query(rt[l - 1], rt[r], 1, n, mid) ;
            cout << v << '
    ';
        }
        return 0;
    }
    
    [南昌邀请赛网络赛 J. Distance on the tree](https://nanti.jisuanke.com/t/38229)
    题目给出一棵树,然后每条边都有一定的边权。然后有多个询问,对于每个询问给出$u,v,k$,回答从$u$到$v$的路径中,权值不大于$k$的边有多少条。

    这个题的做法还是挺多的,可以树剖,也可以直接用树状数组,我就说说主席树的做法吧。
    首先因为是树上的路径,那么我们肯定是要求LCA的。然后我们对整颗数dfs一遍,在dfs过程中插入主席树,以当前结点的父亲结点为历史版本,那么这样我们就清晰地知道从根到当前结点路径的权值信息了。
    对于每次询问,找到LCA,计算一波就行了。这里直接利用主席树找权值不超过k的个数,对于结点(u)(cnt[u]),那么最终的答案就为(cnt[u]+cnt[v]-2*cnt[LCA])
    代码如下:

    Code
    #include <bits/stdc++.h>
    using namespace std;
    typedef long long ll;
    const int N = 1e5 + 5;
    int n, m, T, D;
    int b[N << 1];
    int head[N];
    struct Edge{
        int v, w, next;
    }e[N << 1];
    struct Q{
        int u, v, w;
    }q[N];
    struct edge{
        int u, v, w;
    }E[N << 1];
    int tot;
    void adde(int u, int v, int w) {
        e[tot].v = v; e[tot].next = head[u]; e[tot].w = w; head[u] = tot++;
        e[tot].v = u; e[tot].next = head[v]; e[tot].w = w; head[v] = tot++;
    }
    int f[N][22], deep[N];
    int ls[N * 20], rs[N * 20], rt[N], sum[N * 20];
    void build(int &o, int l, int r) {
        o = ++T;
        if(l == r) return ;
        int mid = (l + r) >> 1;
        build(ls[o], l, mid) ;
        build(rs[o], mid + 1, r) ;
    }
    void update(int &o, int l, int r, int last, int v) {
        o = ++T;
        sum[o] = sum[last] + 1;
        ls[o] = ls[last]; rs[o] = rs[last] ;
        if(l == r) return ;
        int mid = (l + r) >> 1;
        if(v <= mid) update(ls[o], l, mid, ls[last], v) ;
        else update(rs[o], mid + 1, r, rs[last], v) ;
    }
    void dfs(int u, int fa) {
        for(int i = head[u]; i != -1; i = e[i].next) {
            int v = e[i].v;
            if(v == fa) continue ;
            deep[v] = deep[u] + 1;
            f[v][0] = u;
            for(int j = 1; j <= 20; j++) f[v][j] = f[f[v][j - 1]][j - 1] ;
            int w = lower_bound(b + 1, b + D + 1, e[i].w) - b;
            update(rt[v], 1, n + m, rt[u], w) ;
            dfs(v, u);
        }
    }
    int lca(int x, int y) {
        if(deep[x] < deep[y]) swap(x, y) ;
        for(int i = 20; i >= 0; i--)
            if(deep[f[x][i]] >= deep[y]) x = f[x][i] ;
        if(x == y) return x;
        for(int i = 20; i >= 0; i--)
            if(f[x][i] != f[y][i]) x = f[x][i], y = f[y][i] ;
        return f[x][0] ;
    }
    int query(int o, int l, int r, int last, int v) {
        if(l == r) return sum[o] - sum[last] ;
        int mid = (l + r) >> 1;
        int s = sum[ls[o]] - sum[ls[last]] ;
        if(v <= mid) return query(ls[o], l, mid, ls[last], v) ;
        else return s + query(rs[o], mid + 1, r, rs[last], v) ;
    }
    int main() {
        ios::sync_with_stdio(false); cin.tie(0);
        memset(head, -1, sizeof(head)) ;
        cin >> n >> m;
        for(int i = 1; i < n; i++) {
            int u, v, w;
            cin >> u >> v >> w ;
            E[i] = edge{u, v, w} ;
            b[i] = w ;
        }
        D = n;
        for(int i = 1; i <= m; i++) {
            int u, v, w;
            cin >> u >> v >> w ;
            q[i] = Q{u, v, w} ;
            b[D++] = w;
        }
        sort(b + 1, b + D) ;
        D = unique(b + 1, b + D) - b - 1;
        for(int i = 1; i < n; i++) {
            adde(E[i].u, E[i].v, E[i].w) ;
        }
        build(rt[1], 1, n + m) ;
        dfs(1,0) ;
        for(int i = 1; i <= m; i++) {
            int u = q[i].u, v = q[i].v, w = q[i].w;
            w = lower_bound(b + 1, b + D + 1, w) - b;
            int LCA = lca(q[i].u, q[i].v) ;
            int s1 = query(rt[u], 1, n + m, rt[1], w) ;
            int s2 = query(rt[v], 1, n + m, rt[1], w) ;
            int s3 = query(rt[LCA], 1, n + m, rt[1], w) ;
            //cout << w << ' ' << s1 << ' ' << s2 << ' ' << s3 << '
    ';
            cout << s1 + s2 - 2 * s3 << '
    ';
        }
        return 0 ;
    }
    
    
  • 相关阅读:
    python的整除,除法和取模对比
    jq禁用双击事件
    jq判断滑动方向
    jq获取下拉框中的value值
    html字符串转换成纯文字
    内层div相对于外层div水平垂直居中以及外层div相对body水平垂直居中
    python获取用户输入
    js判断浏览器是否支持localStorage
    CLR的执行模型
    行人检测2(行人检测的发展历史)
  • 原文地址:https://www.cnblogs.com/heyuhhh/p/10753168.html
Copyright © 2011-2022 走看看