zoukankan      html  css  js  c++  java
  • 学习记录:线段树

    线段树 V1.0

    之所以叫v1.0呢,是因为这是我第一次学这个数据结构。
    考虑到重要性,以后在做题的过程中会对这篇博客做更新的。

    概念

    线段树是一种二叉搜索树,用于处理区间问题的数据结构。

    与ST表不同的是,线段树支持点,区间修改。相应的,虽然预处理速度与ST表相同,都是(O(logn)),但查询速度比起ST表的(O(1))要慢,是(O(logn))

    线段树是建立在区间二分这个概念上的,树上的每个节点都代表了一段区间。
    如图

    Nb1dmR.jpg

    1. 对于每个区间([L,R])而言,都有一个左端点(L)和右端点(R)
    2. (L=R)时,当前所指区间是一个点。显然,一个点是不能继续拆分的,所以这是一个叶子节点。
      反过来考虑,(L eq R)时,这一段区间必定包括了两个或以上的点,因此必有两个叶节点。
      综上,线段树是没有只有一个子节点的节点的。
    3. (L eq R)时,区间必然可以拆分为两个小区间。这里先设(M=(L+R)/2)
      左子节点的范围是([L,M]),相应的,右子节点的范围是([M+1,R])

    对于二叉树这种结构,一般都用的是递归的方式。用指针难免会比较难处理,所以可以用完全二叉树的数组储存的方式,将线段树存放到数组里。
    对于上图的线段树,用数组储存后的表现是

    NbtqjU.jpg

    这样储存,大概率会需要比较多的空间。一般来说,有n个点时需要4n的空间(2 imes 2^k(2^{k-1}<n<2^k))

    如果学过完全二叉树,那么父子节点的关系就很清楚。设父节点下标为K,则有

    • L=K*2(左节点)
    • R=K*2+1(右节点)

    因为父子节点的关系有2倍关系,经常会用位运算的方式来计算下标,如

    • L=K<<1(向左移一位,相当于*2)
    • R=k<<1|1 (向左移一位,再加上1,相当于*2+1)

    代码实现

    创建线段树

    既然是一种二叉树的结构,一般用递归来做会比较简单。

    const int maxn = 1e2 + 10;
    int a[maxn] = {0, 1, 2, 3, 4, 5, 6, 7};//原数组
    int tree[maxn * 4];//需要建树的
    void print(int n)//输出tree的函数,这个自己随便写写,方便看就行
    {
        for (int i = 1; i < n * 4; i++)
        {
            if ((i & (i - 1)) == 0)
                cout << endl;
            cout << setw(4) << tree[i];
        }
    }
    void Pushup(int k)//更新函数	k:线段树节点下标
    {
        tree[k] = max(tree[k * 2], tree[k * 2 + 1]);//这里以最大值为例,这句话视题目意思而定
    }
    void Build(int l, int r, int k)//建树函数	l:原数组a的左端点	r:原数组a的右端点	k:当前线段树节点下标
    {
        //比如该例中a要建树的范围是1~7,那么l=1,r=7
        //k默认选1.不要选0!0*2=0,失去了找子节点的功能
        //一开始建树的时候,k指的就是根节点所在下标
        if (l == r)//左右端点相等,说明现在是一个点,直接把原数组的东西复制过来
            tree[k] = a[l];
        else//否则就肯定是一段区间
        {
            int m = (l + r) / 2;//确定中点
            Build(l, m, k * 2);//递归建左子树
            Build(m + 1, r, k * 2 + 1);//递归建右子树
            //这两句位置变动没有影响,不过要注意范围和k的值
            Pushup(k);//更新当前节点
        }
    }
    int main()
    {
        Build(1, 7, 1);//a数组下标1~7的建树,tree数组从1开始
        print(7);
        return 0;
    }
    

    结果如图:
    PS:7下面是空的

    Nq86Bj.jpg

    点更新

    点更新很易于理解。从需要更改的根节点出发,将每一个覆盖到这个点的区间都更新一次即可。

    void updata(int p,int val,int l,int r,int k)//p:需要更改的原数组下标	val:增加的值	l:原数组的左端点	r:原数组的右端点	k:
    {
        if (l==r)//说明是单点,加上就好了
            a[p] += val, tree[k] += val;//原数组和线段树数组都加
        else{
            int m = (l + r) / 2;//中点
            if (p<=m){//要修改的点在左子树上,记得有等于号!
                updata(p, val, l, m, k * 2);
            }else{//在右子树上
                updata(p, val, m + 1, r, k * 2 + 1);
            }   
            Pushup(k);//更新当前节点
        }
    }
    

    区间查询

    也很容易理解:查询的是一段区间,我们只需要将这个区间所包含的子区间——也就是在预处理中已经算好的值都拿出来就行

    代码如下:

    int Query(int L, int R, int l, int r, int k)//L,R:要查询的区间范围	l,r:当前的区间范围	k:当前线段树下标
    {
        if (L <= l && r <= R)//当前区间完全包含在查询区间内,直接返回
            return tree[k];
        else
        {
            //可能包括了左端点,也可能有右端点
            int res = 0;//答案,注意初始化的值要随题目意思改变
            int m = (l + r) / 2;//中点
            if (L <= m)//左子树与查询区间有交集
                res = max(res, Query(L, R, l, m, k * 2));//这句话应题目意思而变,该例中是最大值
            if (R >= m+1)//右子树与查询区间有交集,注意右区间从m+1开始
                res = max(res, Query(L, R, m + 1, r, k * 2 + 1));
            return res;//返回答案
        }
    }
    

    其实查询比上面两种操作还是要难,配合图片更容易理解

    这里假设要查询2~5之间的最大值

    UkWD2V.jpg

    区间修改

    大意:指定(i,j leq n),将区间([a,b])的每个数字加c

    直接套用点修改的方式在时间复杂度上并不比直接在数组上修改好,此时要用一种“懒惰”的做法

    lazy-tag:修改整个区间时,只对这个区间进行整体性的修改,内部的每个元素则暂时不做处理。只有当这个线段区间的一致性被破坏时,才对子区间的值做修改。

    模板

    /*
    简写说明:
    cur:当前线段树下标
    l,r:要进行处理的区间
    seg:线段树数组名
    lazy:懒惰标记
    */
    #include <bits/stdc++.h>
    using namespace std;
    
    typedef long long ll;
    typedef unsigned long long ull;
    const int maxn = 1e5 + 10;
    
    ll num[maxn], seg[maxn << 2], lazy[maxn << 2];
    
    void print(int n) //输出tree的函数,这个自己随便写写,方便看就行
    {
        for (int i = 1; i < n * 4; i++)
        {
            if ((i & (i - 1)) == 0)
                cout << endl;
            cout << setw(4) << seg[i];
        }
        cout << endl;
        for (int i = 1; i < n * 4; i++)
        {
            if ((i & (i - 1)) == 0)
                cout << endl;
            cout << setw(4) << lazy[i];
        }
        cout << endl;
    }
    
    void Pushup(int cur)//向上更新函数,这里是求区间和
    {
        seg[cur] = seg[cur << 1] + seg[cur << 1 | 1];
    }
    
    void Pushdown(int cur, int l, int r)
    {
        if (lazy[cur])
        {
            int m = (l + r) >> 1;
            lazy[cur << 1] += lazy[cur];
            lazy[cur << 1 | 1] += lazy[cur];
            seg[cur << 1] += lazy[cur] * (m - l + 1);
            seg[cur << 1 | 1] += lazy[cur] * (r - m);
            lazy[cur] = 0;
        }
    }
    
    void Build(int cur, int l, int r)
    {
        if (l == r)
            seg[cur] = num[l];
        else
        {
            int m = (l + r) >> 1;
            Build(cur << 1, l, m);
            Build(cur << 1 | 1, m + 1, r);
            Pushup(cur);
        }
    }
    
    void Point(int index, int val, int l, int r, int cur)
    {
        if (l == r)
            num[index] += val, seg[cur] += val;
        else
        {
            int m = (l + r) >> 1;
            if (index <= m)
                Point(index, val, l, m, cur << 1);
            else
                Point(index, val, m + 1, r, cur << 1 | 1);
            Pushup(cur);
        }
    }
    
    void updata(int L, int R, int val, int l, int r, int cur)
    {
        if (L <= l && r <= R)
        {
            lazy[cur] += val;
            seg[cur] += val * (r - l + 1);
        }
        else
        {
            Pushdown(cur, l, r);
            int m = (l + r) >> 1;
            if (L <= m)
                updata(L, R, val, l, m, cur << 1);
            if (m < R)
                updata(L, R, val, m + 1, r, cur << 1 | 1);
            Pushup(cur);
        }
    }
    
    ll Query(int L, int R, int l, int r, int cur)
    {
        if (L <= l && r <= R)
            return seg[cur];
        else
        {
            Pushdown(cur, l, r);
            ll res = 0;
            int m = (l + r) >> 1;
            if (L <= m)
                res += Query(L, R, l, m, cur << 1);
            if (R >= m + 1)
                res += Query(L, R, m + 1, r, cur << 1 | 1);
            return res;
        }
    }
    
    int main()
    {
        int n, m;
        cin >> n >> m;
        for (int i = 1; i <= n; i++)
            cin >> num[i];
        Build(1, 1, n);
        int flag, x, y, k;
        while (m--)
        {
            cin >> flag;
            if (flag == 1)
            {
                cin >> x >> y >> k;
                updata(x, y, k, 1, n, 1);
            }
            else
            {
                cin >> x >> y;
                cout << Query(x, y, 1, n, 1) << endl;
            }
        }
        return 0;
    }
    
  • 相关阅读:
    数据库的......
    数据库
    XML
    网络编程
    I/O系统---流
    周结

    集合,框架
    Spring入门
    Java Wed
  • 原文地址:https://www.cnblogs.com/Salty-Fish/p/13261524.html
Copyright © 2011-2022 走看看