zoukankan      html  css  js  c++  java
  • 可持久化线段树

    可持久化线段树

    简介

    可持久化数据结构又称函数式数据结构,其思路来自于函数式编程。在函数式编程中,变量的值是不允许改变的,因而每一次插入元素都必须创建一个新的版本。

    设想一棵二叉树:

            [1]
        [2]     [3]
     [4]  [5] [6]  

    现在为了插入一个新节点,我们必须新建一棵树

            (1)
        (2)     (3)
     (4)  (5) (6)  (7)

    不难发现,很多元素被重复使用了。如果将重复的元素合并,就得到这样一棵树:

            [1]    --->  (1)
        [2]     [3]   [2]   (3)
     [4]  [5] [6]         [6]  (7)

    新建的元素其实只有O(h),如果是一棵平衡树或线段树,新建元素就是O(lgn)

    应用

    可持久化线段树是解决区间问题的锐利武器。考虑第i棵和第j棵线段树Ti,Tj,如果他们的对应元素相减得到一棵新树TjTi,这棵树其实就是区间 [i+1,j] 所对应的线段树。

    例如vijos1459车展一题。用反证法不难证明题目中要求的即是

    i=lr|ximid|

    其中,mid为区间 x[l,r] 的中位数。

    由于涉及了区间中位数,可以考虑使用树套树实现。但树套树代码复杂度较高且不宜于调试,可以考虑用可持久化线段树代替。

    将输入的 xi 按顺序建立一棵可持久化线段树,分别维护sumnum_sum,第一个为区间内元素的和,第二个为区间内元素出现的次数。利用 TrTl1 得到区间 [l,r] 内的线段树来计算。

    Code

    // 可持久化线段树
    // 维护两个值
    #include <bits/stdc++.h>
    using namespace std;
    
    #define maxn 1005
    struct node {
        int l, r, lc, rc;
        long long sum;
        int num_sum;
        node(){l = r = lc = rc = sum = num_sum = 0; }
    }tree[15*maxn];
    int root[200005], top = 0;
    int n, m;
    
    inline long long read()
    {
        long long a = 0; int c;
        do c = getchar(); while(!isdigit(c));
        while (isdigit(c)) {
            a = a*10 + c - '0';
            c = getchar();
        }
        return a;
    }
    
    int sorted[1005]; // 离散化
    int dat[1005]; // 原始数据
    
    inline void update(int i) {
        tree[i].sum = tree[tree[i].lc].sum + tree[tree[i].rc].sum;
        tree[i].num_sum = tree[tree[i].lc].num_sum + tree[tree[i].rc].num_sum;
    }
    
    inline int new_node(int l, int r) {
        tree[++top].l = l;
        tree[top].r = r;
        return top;
    }
    
    void build(int &nd, int l, int r) {
        if (l > r) return;
        if (l == r) {nd = new_node(l, r);return;}
        int mid = (l+r)>>1;
        nd = new_node(l, r);
        build(tree[nd].lc, l, mid);
        build(tree[nd].rc, mid+1, r);
    }
    
    void insert(int pre, int &now, int k, long long dat) {
        if (tree[pre].l == tree[pre].r) {
            now = new_node(k, k);
            tree[now].sum = dat;
            tree[now].num_sum = 1;
        } else {
            now = new_node(tree[pre].l, tree[pre].r);
            tree[now] = tree[pre];
            if (k <= tree[tree[pre].lc].r) insert(tree[pre].lc, tree[now].lc, k, dat);
            else insert(tree[pre].rc, tree[now].rc, k, dat);
            update(now);
        }
    }
    
    // 查找区间和(sum)
    long long get_sum(int pre, int now, int l, int r)
    {
        if (l > r || !pre || !now) return 0;
        if (l == tree[pre].l && r == tree[now].r) return tree[now].sum - tree[pre].sum;
        return get_sum(tree[pre].lc, tree[now].lc, l, min(r, tree[tree[pre].lc].r)) +
               get_sum(tree[pre].rc, tree[now].rc, max(tree[tree[pre].rc].l, l), r);
    }
    
    // 区间内数字个数的和
    int get_num_sum(int pre, int now, int l, int r)
    {
        if (l > r || !pre || !now) return 0;
        if (l == tree[pre].l && r == tree[now].r) return tree[now].num_sum - tree[pre].num_sum;
        return get_num_sum(tree[pre].lc, tree[now].lc, l, min(r, tree[tree[pre].lc].r)) +
               get_num_sum(tree[pre].rc, tree[now].rc, max(tree[tree[pre].rc].l, l), r);
    }
    
    int find_mid(int pre, int now, int k) // 查找中位数(第k个数)的位置
    {
        if (tree[now].l == tree[now].r) return tree[now].l;
        if (tree[tree[now].lc].num_sum - tree[tree[pre].lc].num_sum >= k)
            return find_mid(tree[pre].lc, tree[now].lc, k);
        else
            return find_mid(tree[pre].rc, tree[now].rc, k-(tree[tree[now].lc].num_sum - tree[tree[pre].lc].num_sum));
    }
    
    // 查询区间
    long long query(int l, int r) {
        int pos = find_mid(root[l-1], root[r], ((l+r)>>1)-l+1);
        long long lft = get_sum(root[l-1], root[r], 1, pos);int ln = get_num_sum(root[l-1], root[r], 1, pos);
        long long rgt = get_sum(root[l-1], root[r], pos+1, n);int rn = get_num_sum(root[l-1], root[r], pos+1, n);
        return rgt - rn*sorted[pos] + ln*sorted[pos] - lft;
    }
    
    void dfs(int rt, int tab = 0) {
        if (rt) {
            for (size_t i = 0; i < tab; i++) putchar(' ');
            cout << tree[rt].l << "->" << tree[rt].r << " " << tree[rt].sum << " " << tree[rt].num_sum << endl;
            dfs(tree[rt].lc, tab+2);
            dfs(tree[rt].rc, tab+2);
        }
    }
    
    int main()
    {
        n = read(); m = read();
        build(root[0], 1, n);
        long long a, l, r;
        for (size_t i = 1; i <= n; i++)
            sorted[i] = dat[i] = read();
        sort(sorted+1, sorted+n+1);
        for (size_t i = 1; i <= n; i++) {
            insert(root[i-1], root[i], lower_bound(sorted+1, sorted+n+1, dat[i])-sorted, dat[i]);
        }
        long long ans = 0;
        for (size_t i = 1; i <= m; i++) {
            l = read(); r = read();
            ans += query(l, r);
        }
        cout << ans << endl;
        return 0;
    }
  • 相关阅读:
    __str__
    __call__
    私有成员
    @property
    静态方法
    静态字段
    cut qcut
    hive 函数大全
    sklearn 中的Countvectorizer/TfidfVectorizer保留长度小于2的字符方法
    numpy教程:随机数模块numpy.random
  • 原文地址:https://www.cnblogs.com/ljt12138/p/6684357.html
Copyright © 2011-2022 走看看