zoukankan      html  css  js  c++  java
  • 树状数组 学习笔记

    1.前置知识

    二叉树。

    分治。

    前缀和。

    2.树状数组

    其实就是前缀和用二叉树做。

    将二叉树右对齐即可。

    如这样一颗二叉树

    将它变成这样

    如下图(绿色为 (C) 数组,红色为 (a) 数组)

    (C_{1}=a_{1})

    (\,\,\,\,\,\,C_{2}=a_{1}+a_{2})

    (\,\,\,\,\,\,C_{3}=a_{3})

    (\,\,\,\,\,\,C_{4}=a_{1}+a_{2}+a_{3}+a_{4})

    (\,\,\,\,\,\,C_{5}=a_{5})

    (\,\,\,\,\,\,C_{6}=a_{5}+a_{6})

    (\,\,\,\,\,\,C_{7}=a_{7})

    (\,\,\,\,\,\,C_{8}=a_{1}+a_{2}+a_{3}+a_{4}+a_{5}+a_{6}+a_{7}+a_{8})

    试试找规律?

    全部转为二进制

    0001	001
    0010	001 010
    0011	011
    0100	001 010 011 100
    0101	101
    0110	101 110
    0111	111
    1000	001 010 011 100 101 110 111
    

    不难发现 (C_{i}) 中数的个数为(2)(i) 的二进制中 (1) 的最右边的位置后的 (0) 的个数 次幂。

    读起来很绕口对吧,举个例子,如 ((0100)_{2}),它的最右边的 (1) 后有 (2)(0)(2^{2}=4),所以 (C_{(0100)_{2}}) 中数的个数为 (4)

    那么问题来了,如何求 (i) 的二进制中最右边的 (1) 的位置呢?

    给出如下代码

    inline int lowbit(int x)
    {
        return x&(-x);
    }
    

    解释一下。

    -x 就是将 (x) 连同符号位一起反转再加一的结果,如 (0010) 的反码为 (1110)

    &运算 不用解释了吧。

    运算x&(-x),举个例子,(0101) 的反码为 (1011),与 (0101) 进行 &运算(0001) ,也就是 (1),这就找到了 (i) 的二进制中最右边的 (1) 的位置。

    3.单点更新,区间查询

    inline void update(int x,int y)//表示将a[x]+y
    {
        for(register int i=x;i<=n;i+=lowbit(i)) a[i]+=y;//每层更新
    }
    

    将每层与 (a_{x}) 相关的值更新一下。

    inline int getsum(int x)//求C[x]的值
    {
        ans=0;
        for(register int i=x;i;i-=lowbit(i)) ans+=a[i];
        return ans;
    }
    

    将每层与 (C_{x}) 相关的值相加求和。

    然后用前缀和做就行啦。

    即区间 ((x,y)) 的值为 getsum(y)-getsum(x-1)

    模板题1

    模板题2

    模板题3

    仅给出 模板1 的代码(其实都差不多)。

    #include<bits/stdc++.h>
    using namespace std;
    int ans;
    int n,m;
    int x,y,z;
    int num;
    int a[500002];
    inline int read()
    {
        int s=0,w=1;
        char ch=getchar();
        while(ch<'0'||ch>'9') {if(ch=='-')w=-1;ch=getchar();}
        while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();
        return s*w;
    }
    inline void write(int x)
    {
        if(x<0) putchar('-'),x=-x;
        if(x>9) write(x/10);
        putchar(x%10+'0');
    }
    inline void print(int x)
    {
        write(x);
        putchar('
    ');
    }
    inline int lowbit(int x)
    {
        return x&(-x);
    }
    inline void update(int x,int y)
    {
        for(register int i=x;i<=n;i+=lowbit(i)) a[i]+=y;
    }
    inline int getsum(int x)
    {
        ans=0;
        for(register int i=x;i;i-=lowbit(i)) ans+=a[i];
        return ans;
    }
    int main()
    {
        n=read();m=read();
        for(register int i=1;i<=n;++i)
        {
            z=read();
            update(i,z);
        }
        for(register int i=1;i<=m;++i)
        {
            num=read();x=read();y=read();
            if(num==1) update(x,y);
                else print(getsum(y)-getsum(x-1));
        }
        return 0;
    }
    

    4.区间更新,单点查询

    inline int lowbit(int x)
    {
        return x&(-x);
    }
    inline void update(int x,int y)
    {
        for(register int i=x;i<=n;i+=lowbit(i)) a[i]+=y;
    }
    inline int getsum(int x)
    {
        ans=0;
        for(register int i=x;i;i-=lowbit(i)) ans+=a[i];
        return ans;
    }
    

    这些代码不会变。

    多了个差分。

    差分讲解一下。

    有如下 (a) 数组

    现在要将 ((2,5)) 这个区间里的值都加一。

    直接循环复杂度肯定不优。

    考虑将 (a_{2}+1,a_{5+1}-1)

    即原数组为

    这样在查询时可以定一个 (ans),边循环边加,然后输出。

    a[x]--,a[y+1]++ //差分
    
    for i←1 to n+1
        do s+=a[i] //统计
           write(s,' ') //输出
    

    ( exttt{Q}):为何要这样差分?

    ( exttt{A}):在查询时将值赋为当前正确的值,在查询完减去即可。

    于是可得差分代码

    inline void add(int l,int r,int x)//对(l,r)的区间进行差分
    {
        update(l,x);update(r+1,-x);
    }
    //(应该不难理解吧)
    

    模板题

    直接贴代码。

    #include<bits/stdc++.h>
    using namespace std;
    int ans;
    int n,m;
    int x,y,k;
    int now,last;
    int num;
    int a[500002];
    inline int read()
    {
        int s=0,w=1;
        char ch=getchar();
        while(ch<'0'||ch>'9') {if(ch=='-')w=-1;ch=getchar();}
        while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();
        return s*w;
    }
    inline void write(int x)
    {
        if(x<0) putchar('-'),x=-x;
        if(x>9) write(x/10);
        putchar(x%10+'0');
    }
    inline void print(int x)
    {
        write(x);
        putchar('
    ');
    }
    inline int lowbit(int x)
    {
        return x&(-x);
    }
    inline void update(int x,int y)
    {
        for(register int i=x;i<=n;i+=lowbit(i)) a[i]+=y;
    }
    inline int getsum(int x)
    {
        ans=0;
        for(register int i=x;i;i-=lowbit(i)) ans+=a[i];
        return ans;
    }
    inline void add(int l,int r,int x)
    {
        update(l,x);update(r+1,-x);
    }
    int main()
    {
        n=read();m=read();
        for(register int i=1;i<=n;++i)
        {
            now=read();
            update(i,now-last);
            last=now;
        }
        for(register int i=1;i<=m;++i)
        {
            num=read();
            if(num==1)
            {
                x=read();y=read();k=read();
                add(x,y,k);
            }
            else
            {
                x=read();
                print(getsum(x));
            }
        }
        return 0;
    }
    

    5.总结

    参考资料:

    https://www.cnblogs.com/xenny/p/9739600.html

    https://blog.csdn.net/bestsort/article/details/80796531

    https://www.luogu.com.cn/blog/kingxbz/shu-zhuang-shuo-zu-zong-ru-men-dao-ru-fen

    练习:求逆序对。

    #include<bits/stdc++.h>
    #define int long long
    using namespace std;
    struct arr
    {
        int sum,num;
    }A[500002];
    int a[500002];
    int f[500002];
    int n;
    int x;
    int ans;
    inline int read()
    {
        int s=0,w=1;
        char ch=getchar();
        while(ch<'0'||ch>'9') {if(ch=='-')w=-1;ch=getchar();}
        while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();
        return s*w;
    }
    inline void write(int x)
    {
        if(x<0) putchar('-'),x=-x;
        if(x>9) write(x/10);
        putchar(x%10+'0');
    }
    inline void print(int x)
    {
        write(x);
        putchar('
    ');
    }
    inline int lowbit(int x)
    {
        return x&(-x);
    }
    inline void update(int x,int y)
    {
        for(int i=x;i<=n;i+=lowbit(i)) f[i]+=y;
    }
    inline int getsum(int x)
    {
        int sum=0;
        for(int i=x;i;i-=lowbit(i)) sum+=f[i];
        return sum;
    }
    bool cmp(arr x,arr y)
    {
        if(x.sum!=y.sum) return x.sum<y.sum;
        return x.num<y.num;
    }
    signed main()
    {
        n=read();
        for(int i=1;i<=n;++i) A[i].sum=read(),A[i].num=i;
        sort(A+1,A+n+1,cmp);
        for(int i=1;i<=n;++i) a[A[i].num]=i;
        for(int i=1;i<=n;++i)
        {
            update(a[i],1);
            ans+=i-getsum(a[i]);
        }
        print(ans);
        return 0;
    }
    
  • 相关阅读:
    RabbitMq(四)远程过程调用RPC
    RabbitMq(三)交换机类型
    RabbitMq(二)工作队列
    java基础知识01--JAVA准备
    匿名子类
    网络之Socket详解
    网络之Socket、TCP/IP、Http关系分析
    Eclipse搭建springboot项目(九)常用Starter和整合模板引擎thymeleaf
    Vue学习——Router传参问题
    sql函数——find_in_set()
  • 原文地址:https://www.cnblogs.com/wuzhenyu/p/14701333.html
Copyright © 2011-2022 走看看