zoukankan      html  css  js  c++  java
  • c++ 线段树

    关于线段树

    线段数是一种区间树

    可以看出:叶子即为输入的数
    假设一个节点为 x ,则其左儿子为 2x 右儿子为 2x+1

    操作解析

    约定

    变量名 意义
    input[] 输入的数
    t[] 线段树
    其中 t[] 是个结构体,包含左边界 l ,右边界 r 和区间和 sum
    sum 并不是必须有的,这些维护的值需要根据题目要求增多、减少

    基本操作

    卡常必备

    左儿子与右儿子
    #define ls rt<<1
    #define rs rt<<1|1

    push_up

    这里只是更新区间和,如有更多操作还需更改
    当前区间和 = 左儿子区间和 + 右儿子区间和

    inline void push_up(int rt){
        t[rt].sum=t[ls].sum+t[rs].sum;
    }
    
    build

    作用:构建线段树
    思路:

    1. 给当前的 l 和 r 区间赋值
    2. 判断是否为叶子节点,是就把当前位置的 sum 赋为 input[l] 并返回
    3. 否则继续构建
    void build(int l,int r,int rt){
        t[rt].l=l,t[rt].r=r;
        if(l==r){
            t[rt].sum=input[l];
            return;
        }
        int mid=(l+r)>>1;
        build(l,mid,ls);
        build(mid+1,r,rs);
        push_up(rt);
    }
    

    正式开始

    单点修改

    作用:把位置 p 的值加上 k
    当然我们需要维护
    如何判断左右子树是否包含 p 呢?

    左子树的右边界大于等于 p 就算包含
    右子树的左边界小于等于 p 也是包含
    思路:

    1. 先将当前位置的 sum 加上 k
    2. 如果达到叶子,返回
    3. 判断左右子树是否包含并继续更新
    4. push_up
    void add(int p,int k,int rt){
        t[rt].sum+=k;
        if(t[rt].l==t[rt].r) return;
        if(p<=t[ls].r) add(p,k,ls);
        if(p>=t[rs].l) add(p,k,rs);
        push_up(rt);
        return;
    }
    

    区间修改(加法)

    作用:把 [l,r] 区间加上 k
    运用了懒标记思想, add 表示当前区间需要加上多少

    下传标记

    把 add 传到左右子树并更新 sum
    sum 显然就要加上区间长度乘 add

    inline void down(int rt){
        if(t[rt].add){
            t[ls].sum+=(t[ls].r-t[ls].l+1)*t[rt].add;
    	t[rs].sum+=(t[rs].r-t[rs].l+1)*t[rt].add;
    	t[ls].add+=t[rt].add;
    	t[rs].add+=t[rt].add;
    	t[rt].add=0;		
        }
    }
    
    递归修改

    思路:

    1. 如果该区间被完全包含,更新 sum 打上标记并返回
    2. down
    3. 判断左右区间是否包含并继续更新
    4. push_up
    void pls(int l,int r,int k,int rt){
        if(l<=t[rt].l&&r>=t[rt].r){
            t[rt].sum+=k*(t[rt].r-t[rt].l+1);
    	t[rt].add+=k;
    	return;
        }
        down(rt);
        if(l<=t[ls].r) pls(l,r,k,ls);
        if(r>=t[rs].l) pls(l,r,k,rs);
        push_up(rt);	
    }
    

    单点查询

    思路:

    1. 如果找到该点,返回 sum
    2. 判断左右区间是否包含并继续查找
    long long search(int p,int rt){
        if(t[rt].l==p&&t[rt].r==p)
            return t[rt].sum;
        if(p<=t[ls].r) return search(p,ls);
        if(p>=t[rs].l) return search(p,rs);
    } 
    

    区间查询

    思路:

    1. 如果区间被完全包含,返回 sum
    2. 判断左右区间是否包含并把查找的值加到 s
    3. 返回 s
    long long query(int l,int r,int rt){
        if(l<=t[rt].l&&r>=t[rt].r)
            return t[rt].sum;
        long long s=0;
        if(l<=t[ls].r) s+=query(l,r,ls);
        if(r>=t[rs].l) s+=query(l,r,rs);
        return s;
    }
    

    例题

    Warning

    1. 如果遇到需要区间修改的,查询时一定要下传标记
    2. 十年 OI 一场空,不开 long long 见祖宗

    区改 + 区查

    洛谷 P3372

    #include<bits/stdc++.h>
    #define ls rt<<1
    #define rs rt<<1|1 
    using namespace std;
    typedef long long ll;
    struct QwQ{
        int l,r;
        ll sum,add;
    }t[2000010];
    inline void push_up(int rt){
        t[rt].sum=t[ls].sum+t[rs].sum;
    }
    inline void down(int rt){
        if(t[rt].add){
    	t[ls].sum+=(t[ls].r-t[ls].l+1)*t[rt].add;
    	t[rs].sum+=(t[rs].r-t[rs].l+1)*t[rt].add;
    	t[ls].add+=t[rt].add;
    	t[rs].add+=t[rt].add;
    	t[rt].add=0;		
        }
    }
    int input[500002];
    void build(int l,int r,int rt){
        t[rt].l=l,t[rt].r=r;
        if(l==r){
    	t[rt].sum=input[l];
    	return;
        }
        int mid=l+r>>1;
        build(l,mid,ls);
        build(mid+1,r,rs);
        push_up(rt);
    }
    void pls(int l,int r,int k,int rt){
        if(l<=t[rt].l&&r>=t[rt].r){
            t[rt].sum+=k*(t[rt].r-t[rt].l+1);
    	t[rt].add+=k;
    	return;
        }
        down(rt);
        if(l<=t[ls].r) pls(l,r,k,ls);
        if(r>=t[rs].l) pls(l,r,k,rs);
        push_up(rt);	
    }
    ll query(int l,int r,int rt){
        if(l<=t[rt].l&&r>=t[rt].r)
            return t[rt].sum;
        down(rt);
        ll s=0;
        if(l<=t[ls].r) s+=query(l,r,ls);
        if(r>=t[rs].l) s+=query(l,r,rs);
        return s;
    }
    int n,m,opt,x,y,k;
    int main(){
        scanf("%d%d",&n,&m);
        for(int i=1;i<=n;i++) scanf("%d",&input[i]);
        build(1,n,1);
        while(m--){
            scanf("%d%d%d",&opt,&x,&y);
    	if(opt==1){
    	    scanf("%d",&k);
    	    pls(x,y,k,1);
    	}
    	else printf("%lld\n",query(x,y,1));
        }
    }
    

    区改 + 单查

    洛谷 P3368

    #include<bits/stdc++.h>
    #define ls rt<<1
    #define rs rt<<1|1 
    using namespace std;
    typedef long long ll;
    struct QwQ{
        int l,r;
        ll sum,add;
    }t[2000010];
    inline void push_up(int rt){
        t[rt].sum=t[ls].sum+t[rs].sum;
    }
    inline void down(int rt){
        if(t[rt].add){
            t[ls].sum+=(t[ls].r-t[ls].l+1)*t[rt].add;
    	t[rs].sum+=(t[rs].r-t[rs].l+1)*t[rt].add;
    	t[ls].add+=t[rt].add;
    	t[rs].add+=t[rt].add;
    	t[rt].add=0;		
        }
    }
    int input[500002];
    void build(int l,int r,int rt){
        t[rt].l=l,t[rt].r=r;
        if(l==r){
    	t[rt].sum=input[l];
    	return;
        }
        int mid=l+r>>1;
        build(l,mid,ls);
        build(mid+1,r,rs);
        push_up(rt);
    }
    void pls(int l,int r,int k,int rt){
        if(l<=t[rt].l&&r>=t[rt].r){
            t[rt].sum+=k*(t[rt].r-t[rt].l+1);
    	t[rt].add+=k;
    	return;
        }
        down(rt);
        if(l<=t[ls].r) pls(l,r,k,ls);
        if(r>=t[rs].l) pls(l,r,k,rs);
        push_up(rt);	
    }
    ll search(int l,int r,int rt){
        if(l<=t[rt].l&&r>=t[rt].r)
     	return t[rt].sum;
        down(rt);
        ll s=0;
        if(l<=t[ls].r) s+=search(l,r,ls);
        if(r>=t[rs].l) s+=search(l,r,rs);
        return s;
    }
    ll search(int p,int rt){
        if(t[rt].l==p&&t[rt].r==p)
            return t[rt].sum;
        down(rt);
        if(p<=t[ls].r) return search(p,ls);
        if(p>=t[rs].l) return search(p,rs);
    } 
    int n,m,opt,x,y,k;
    int main(){
        scanf("%d%d",&n,&m);
        for(int i=1;i<=n;i++) scanf("%d",&input[i]);
        build(1,n,1);
        while(m--){
            scanf("%d%d",&opt,&x);
            if(opt==1){
    	    scanf("%d%d",&y,&k);
    	    pls(x,y,k,1);
    	}
    	else printf("%lld\n",search(x,1));
        }
    }
    

    单改 + 区查

    洛谷 P3374

    #include<bits/stdc++.h>
    #define ls rt<<1
    #define rs rt<<1|1 
    using namespace std;
    struct QwQ{int l,r,sum;}t[2000010];
    inline void push_up(int rt){
        t[rt].sum=t[ls].sum+t[rs].sum;
    }
    int input[500002];
    void build(int l,int r,int rt){
        t[rt].l=l,t[rt].r=r;
        if(l==r){
            t[rt].sum=input[l];
    	return;
        }
        int mid=l+r>>1;
        build(l,mid,ls);
        build(mid+1,r,rs);
        push_up(rt);
    }
    void add(int p,int k,int rt){
        t[rt].sum+=k;
        if(t[rt].l==t[rt].r)
            return;
        if(p<=t[ls].r) add(p,k,ls);
        if(p>=t[rs].l) add(p,k,rs);
        push_up(rt);
        return;
    }
    int search(int l,int r,int rt){
        if(t[rt].l>=l&&t[rt].r<=r)
            return t[rt].sum;
        int s=0;
        if(t[ls].r>=l) s+=search(l,r,ls);
        if(t[rs].l<=r) s+=search(l,r,rs);
        return s;
    }
    int n,m,opt,x,y,k;
    int main(){
        scanf("%d%d",&n,&m);
        for(int i=1;i<=n;i++) scanf("%d",&input[i]);
        build(1,n,1);
        while(m--){
            scanf("%d%d%d",&opt,&x,&y);
    	if(opt==1) add(x,y,1);
    	else printf("%d\n",search(x,y,1));
        }
    }
    

    复杂的区间操作

    区间乘法

    例题

    洛谷 P3373

    解析

    多了一个懒标记 mul ,初值为 1
    根据优先级,先乘再加,运算时 mod 不要忘
    更新 mul 时 add 也对应乘一下,保证精度

    代码
    #include<bits/stdc++.h>
    #define ls rt<<1
    #define rs rt<<1|1 
    using namespace std;
    typedef long long ll;
    struct QwQ{
        int l,r;
        ll sum,add,mul;
    }t[2000010];
    int input[500002],mod;
    inline void push_up(int rt){
        t[rt].sum=(t[ls].sum+t[rs].sum)%mod;
    }
    inline void down(int rt){
        t[ls].sum=(t[ls].sum*t[rt].mul+(t[ls].r-t[ls].l+1)*t[rt].add)%mod;
        t[rs].sum=(t[rs].sum*t[rt].mul+(t[rs].r-t[rs].l+1)*t[rt].add)%mod;
        t[ls].mul=(t[ls].mul*t[rt].mul)%mod;
        t[rs].mul=(t[rs].mul*t[rt].mul)%mod;
        t[ls].add=(t[ls].add*t[rt].mul+t[rt].add)%mod;
        t[rs].add=(t[rs].add*t[rt].mul+t[rt].add)%mod;
        t[rt].mul=1,t[rt].add=0;
    }
    void build(int l,int r,int rt){
        t[rt].l=l,t[rt].r=r,t[rt].mul=1;
        if(l==r) t[rt].sum=input[l];
        else{
            int mid=l+r>>1;
    	build(l,mid,ls);
    	build(mid+1,r,rs);
    	push_up(rt);		
        }
        t[rt].sum%=mod;
    }
    void xMul(int l,int r,int k,int rt){
        if(l<=t[rt].l&&r>=t[rt].r){
    	t[rt].sum=(t[rt].sum*k)%mod;
    	t[rt].mul=(t[rt].mul*k)%mod;
    	t[rt].add=(t[rt].add*k)%mod;
    	return;
        }
        down(rt);
        if(l<=t[ls].r) xMul(l,r,k,ls);
        if(r>=t[rs].l) xMul(l,r,k,rs);
        push_up(rt);
    }
    void pls(int l,int r,int k,int rt){
        if(l<=t[rt].l&&r>=t[rt].r){
            t[rt].sum=(t[rt].sum+k*(t[rt].r-t[rt].l+1))%mod;
    	t[rt].add=(t[rt].add+k)%mod;
            return;
        }
        down(rt);
        if(l<=t[ls].r) pls(l,r,k,ls);
        if(r>=t[rs].l) pls(l,r,k,rs);
        push_up(rt);	
    }
    ll query(int l,int r,int rt){
        if(l<=t[rt].l&&r>=t[rt].r)
     		return t[rt].sum;
     	down(rt);
        ll s=0;
        if(l<=t[ls].r) s+=query(l,r,ls);
        if(r>=t[rs].l) s+=query(l,r,rs);
        return(s%mod);
    }
    int n,m,opt,x,y,k;
    int main(){
        scanf("%d%d%d",&n,&m,&mod);
        for(int i=1;i<=n;i++) scanf("%d",&input[i]);
        build(1,n,1);
        while(m--){
    	scanf("%d%d%d",&opt,&x,&y);
    	if(opt==1){
    	    scanf("%d",&k);
    	    xMul(x,y,k,1);
    	}
    	else if(opt==2){
    	    scanf("%d",&k);
    	    pls(x,y,k,1);
    	}
    	else printf("%lld\n",query(x,y,1));
        }
    }
    

    区间开方

    例题

    洛谷 P4145

    解析

    这题的突破口在于: \(\sqrt{1}=1\)
    由于是向下取整,所以最多开方六次就不变了

    我们可以省去懒标记,多加一个 fir 表示区间最大值,区间开方时如果 fir 小于等于 1 就无须继续修改了
    当修改到达叶子节点,把当前节点的 sum 和 fir 都开个方并返回,因为返回之后上一层会 push_up ,达到修改效果

    代码

    这题唯一坑点:左区间会比右区间大,需要交换

    #include<bits/stdc++.h>
    #define ls rt<<1
    #define rs rt<<1|1 
    using namespace std;
    typedef long long ll;
    struct QwQ{
        int l,r;
        ll sum,fir;
    }t[400010];
    inline void push_up(int rt){
        t[rt].sum=t[ls].sum+t[rs].sum;
        t[rt].fir=max(t[ls].fir,t[rs].fir);
    }
    ll input[100005];
    void build(int l,int r,int rt){
        t[rt].l=l,t[rt].r=r;
        if(l==r){
            t[rt].sum=t[rt].fir=input[l];
    	return;
        }
        int mid=l+r>>1;
        build(l,mid,ls);
        build(mid+1,r,rs);
        push_up(rt);
    }
    void xSqrt(int l,int r,int rt){
        if(t[rt].l==t[rt].r){
    	t[rt].sum=sqrt(t[rt].sum);
    	t[rt].fir=sqrt(t[rt].fir);
    	return;
        }
        if(l<=t[ls].r&&t[ls].fir>1) xSqrt(l,r,ls); 
        if(r>=t[rs].l&&t[rs].fir>1) xSqrt(l,r,rs);
        push_up(rt);
    }
    ll search(int l,int r,int rt){
        if(l<=t[rt].l&&r>=t[rt].r)
            return t[rt].sum;
        ll s=0;
        if(l<=t[ls].r) s+=search(l,r,ls);
        if(r>=t[rs].l) s+=search(l,r,rs);
        return s;
    }
    int n,m,opt,x,y;
    int main(){
        scanf("%d",&n);
        for(int i=1;i<=n;i++) scanf("%lld",&input[i]);
        build(1,n,1);
        scanf("%d",&m); 
        while(m--){
            scanf("%d%d%d",&opt,&x,&y);
    	if(x>y) x^=y^=x^=y;
    	if(opt==0) xSqrt(x,y,1);
    	else printf("%lld\n",search(x,y,1));
        }
    }
    


    The End

  • 相关阅读:
    linq to access 简单实现 实例demo
    FCKEDITOR中文使用说明 js调用
    asp.net mvc 随想
    fccms 小型简单个人blog源码
    PHP文件上传路径
    前端优化技巧(一)
    会话框拖拽效果实现
    phpmailer配置
    上传图片动态预览(兼容主流浏览器)
    Java将多个list对象根据属性分组后合并成一个新的集合
  • 原文地址:https://www.cnblogs.com/KonjakLAF/p/12821775.html
Copyright © 2011-2022 走看看