树状数组的应用其实就是一个巧妙地运用了二进制运算来进行 logn 插入、 logn 查询 的 前缀和 算法。
原理分析:
假设有 9 个数字组成的数组:
A[] = 1 2 5 4 3 7 8 6 9
我们使得树状数组 c[] 以以下方法存储:
C[1] = A[1]
C[2] = A[1] + A[2]
C[3] = A[3]
C[4] = A[1] + A[2] + A[3] + A[4]
C[5] = A[5]
C[6] = A[5] + A[6]
C[7] = A[7]
C[8] = A[1] + A[2] + A[3] + A[4] + A[5] + A[6] + A[7] + A[8]
C[9] = A[9]
这样我们是看不出来咋存滴,我们先写成 二进制 再来看看:
C[1] = A[1]
C[10] = A[1] + A[2]
C[11] = A[3]
C[100] = A[1] + A[2] + A[3] + A[4]
C[101] = A[5]
C[110] = A[5] + A[6]
C[111] = A[7]
C[1000] = A[1] + A[2] + A[3] + A[4] + A[5] + A[6] + A[7] + A[8]
C[1001] = A[9]
重点来了,现在我们来构造 C 数组 ,我们观察 A[] 在哪里出现过:
A[1] →→ C[1] 、C[10] 、 C[100] 、C[1000]
A[2] →→ C[10] 、C[100] 、C[1000]
A[3] →→ C[11] 、C[100] 、C[1000]
A[4] →→ C[100] 、C[1000]
A[5] →→ C[110] 、C[1000]
A[6] →→ C[110] 、C[1000]
...... ..........................
好我现在告诉你啥规律~
比如我们来看 A[3] : 3 的二进制是 1 1 ,它先对自己的 C[3(11)] 有贡献,故 C[3] += A[3] 。然后 A[3] 再对 C[4(110)] 有贡献,故 C[4] += A[3] 。同样,它还对 C[8(1000)] 有贡献,即 C[8] += A[3] 。
那 3 是如何一步一步转化成 4 、8 的呢?
3 的二进制是 1 1 1 1 保留它最低位的 1 ,其它位的 1 去掉(即变为 0),则有:0 1 ,然后使 1 1 加上 0 1 → 1 0 0 (即 4)
4 的二进制是 1 0 0 1 0 0 保留它最低位的 1,则有:1 0 0 ,然后使 1 0 0 加上 1 0 0 → 1 0 0 0(即 8)
其他什么,A[1] 、A[2] ......同理,大家可以试着写一下
那么 “ 保留它最低位的 1 ,其它位的 1 去掉(即变为 0)” 这句话所变成的二进制,再转化为十进制数的值,如何求呢?
这里需要一个 自定义的 lowbit 函数,这个函数可以得出上面问题的答案:
ll lowbit(ll x){return x&(-x);}
故我们可以构造出树状数组了~(实现看注释)
inline void update(int x,ll k) // x 表示 c[i] 中的 i ,k 表示 为当前 c[x] 贡献的值 { while(x<=n){ //当然超出 c 数组长度的,就不需要构造。比如 x=3 x=4 x=8 然后再到 x=16 ,16我们就不要存了,因为题目我只需要 C[1] 到 C[9] 的值 c[x]+=k; x+=lowbit(x); // 比如上面的例子A[3],这样写就可以使得 x=3 转化为 x=4 x=8 ,然后 c[4]+= A[3] ,c[8]+= A[3] 了 } } int main() { scanf("%d",&n); for(int i=1;i<=n;i++){ scanf("%lld",&a[i]); update(i,a[i]); } }
那么构造出了 C数组,如何求出前缀和的呢?
比如现在要知道
A[1] + A[2] + A[3] + A[4] + A[5] 的值: C[5] + C[4] , 5 的二进制: 1 0 1 , 4 的二进制: 1 0 0 ,差 0 0 1
A[1] + A[2] + A[3] + A[4] + A[5] + A[6] 的值: C[6] + C[4] , 6 的二进制: 1 1 0 ,4 的二进制: 1 0 0 ,差 0 1 0
那 A[1] + A[2] + ..... + A[9] 的值: C[8] + C[9] , 9 的二进制: 1 0 0 1 , 8的二进制: 1 0 0 0 ,差 0 0 0 1
而现在我们可以观察到:
5 与 4 的差值为 5 - lowbit(5)
6 与 4 的差值为 6 - lowbit(6)
9 与 8 的差值为 9 - lowbit(9)
跟构造 C 数组一样,每次用 ans += C[x] ,然后 x-=lowbit(x) ,直到 x 自己减自己,为 0 时结束,这样我就能得到 1 ~ x 的前缀和了~
ll query(int x) { ll ans=0; while(x){ ans+=c[x]; x-=lowbit(x); } return ans; }
到这里,树状数组的板子就完成了~ 总代码在这:
#include<iostream> #include<algorithm> #include<string.h> #define maxn 500008 using namespace std; typedef long long ll; int n; ll a[maxn],c[maxn]; ll lowbit(ll x){return x&(-x);} inline void update(int x,ll k) { while(x<=n){ c[x]+=k;// 依次把 k 加入到能贡献的 C 中 x+=lowbit(x); // x 往上移 } } ll query(int x) { ll ans=0; while(x){ ans+=c[x]; // 通过累加获得前缀和 x-=lowbit(x); // x 往下移 } return ans; } int main() { scanf("%d",&n); for(int i=1;i<=n;i++){ scanf("%lld",&a[i]); update(i,a[i]); //将 a[i] 所能提供的贡献,加入到 C 数组中 } for(int i=1;i<=n;i++){ // 依次求前缀和 printf("%lld ",query(i) ); } cout<<endl; }
例题1:洛谷 P3374
该题是 树状数组的 单点修改、区间查询的模板题。
由于我们是通过存储每个数,然后把每个数贡献给 C ,然后统计 C 的总和,得出前缀和的。
1、所以当单点修改加上 p ,我们只需要将这个点再次加上 p 就可以了,是不是很方便~
2、现在我们可以知道所有 1 ~ n 的前缀和了,那么区间 l ~ r 总和不就是 1 ~ r 的前缀和 减去 1 ~ l -1 的前缀和吗(注意:这里不是 1 ~ l ,因为前缀和 1 ~ l 的值中包含了 A[l] ,如果减去的是 1 ~ l,那么会减去 A[l] 导致答案不正确)
代码如下:
#include<iostream> #include<algorithm> #include<string.h> #define maxn 500008 using namespace std; typedef long long ll; int n,m; ll a[maxn],c[maxn]; ll lowbit(ll x){return x&(-x);} inline void update(int x,ll k) { while(x<=n){ c[x]+=k; x+=lowbit(x); } } ll query(int x) { ll ans=0; while(x){ ans+=c[x]; x-=lowbit(x); } return ans; } int main() { scanf("%d%d",&n,&m); for(int i=1;i<=n;i++){ scanf("%lld",&a[i]); update(i,a[i]); } int A,B,C; while(m--) { scanf("%d%d%d",&A,&B,&C); if(A==1){ update(B,C); } else{ printf("%lld ",query(C)-query(B-1)); } } }
例题2:洛谷 P3368
该题是 树状数组的 单点修改、区间查询的模板题。需要用到 “差分” 的概念。
看完差分概念,想必你就可以很轻松 A 这题了~
#include<iostream> #include<algorithm> #include<string.h> #define maxn 500008 using namespace std; typedef long long ll; int n,m; ll a[maxn],c[maxn]; ll lowbit(ll x){return x&(-x);} inline void update(int x,ll k) { while(x<=n){ c[x]+=k; x+=lowbit(x); } return; } inline ll query(int x) { ll ans=0; while(x){ ans+=c[x]; x-=lowbit(x); } return ans; } int main() { scanf("%d%d",&n,&m); for(int i=1;i<=n;i++){ scanf("%lld",&a[i]); update(i,a[i]-a[i-1]); } int A,B,C,D; while(m--) { scanf("%d%d",&A,&B); if(A==1){ scanf("%d%d",&C,&D); update(B,D),update(C+1,-D); } else printf("%lld ",query(B)); } }
例题3:洛谷 P1908
该题是 树状数组求 逆序对数 的模板题。
我们先不用树状数组,先用单单的前缀和思想,来求解逆序数。
例如数组 A[] = 1 4 5 3 2
然后我们模拟出一个标记数组 B ,来表A[i] 是否出现,出现则使 B[ A[i] ] 等于 1 。一开始初始化为 0 即可。
1 | 2 | 3 | 4 | 5 |
0 | 0 | 0 | 0 | 0 |
来模拟一遍~
A[1] = 1 ,即 1 出现过,使得 B[1] = 1。 下标 1 表示 A[1] 实际出现位置,A[1] 表示在 1、2、3、4.....中实际位置。1 的前缀和为 1 (包括自己)
1 | 2 | 3 | 4 | 5 |
1 | 0 | 0 | 0 | 0 |
A[2] = 4 ,即 4 出现过,使得 B[4] = 1。下标 2 表示 A[2] 实际出现位置,A[2] 表示在 1、2、3、4.....中实际位置。4 的前缀和为 2 (包括自己)
1 | 2 | 3 | 4 | 5 |
1 | 0 | 0 | 1 | 0 |
A[3] = 5 ,即 5 出现过,使得 B[5] = 1 。下标 3 表示 A[3] 实际出现位置,A[3] 表示在 1、2、3、4.....中实际位置。5 的前缀和为 3 (包括自己)
1 | 2 | 3 | 4 | 5 |
1 | 0 | 0 | 1 | 1 |
A[4] = 3 ,即 3 出现过,使得 B[3] = 1 。下标 4 表示 A[4] 实际出现位置,A[4] 表示在 1、2、3、4.....中实际位置。3 的前缀和为 2 (包括自己)
1 | 2 | 3 | 4 | 5 |
1 | 0 | 1 | 1 | 1 |
诶!重点的来了!你看,下标为 4 ,但是此时 A[4] = 3 ,它的前缀和是 2 ,小于了 4 !这里说明有 4 - 2 个 逆序对,即 (5,3) 和 (4,3)!
原因:
我们是使 A[i] 出现过后,标记它出现过。i 为A[i] 的真正位置,而 A[i] 才为它的实际位置。
现在 i 都到 4 了,说明除了 3 自己之外,应该还有 3 个数已经被标记,由于都是以 1 作标记,所以此时 A[i] = 3 它前面仅有 1 个数 1 才被标记过,说明有 2 个数不在 3 的前面呀,那就是说明这两个数都比 3 大,而且在 3 之前就出现过~
那么很容易的,我们得到以 (P,3) 为逆序对的个数为: i - S(3) (S(3)表示 3 的前缀和, P > 3)
你可以理解成: i 为已经被标记过的总数,由于都是用 1 标记,所以S(3)表示有多少个小于 3 的数出现过,那么 i - S(3) 就是那些出现过的但大于 3 的数了~
嗯求前缀和,用树状数组!
但是这题还需要用到离散化。
那么用到离散化求逆序数的话,就有一个很大的问题!
如果用 sort 排序,由于 sort 为不稳定排序,会导致相同的数,位置发生调换。
比如 A[] = 1 4 4 2 ,很明显,离散化后的顺序应该是: 1 3 4 2 ,而不能是 1 4 3 2 。否则会多出一个 (4,3)的逆序对,导致答案不符。
所以我们只需要更早出现的,sort 更早标记就行了~
代码如下:
#include<iostream> #include<algorithm> #include<string.h> #define maxn 500008 using namespace std; typedef long long ll; int n; int a[maxn]; ll c[maxn]; struct Node { int id; ll val; }A[maxn]; inline int lowbit(int x){return x&(-x);} bool cmp(const Node q,const Node w){ if(q.val==w.val) return q.id<w.id; return q.val<w.val; } inline void update(int x,int k) { while(x<=n){ c[x]+=k; x+=lowbit(x); } return; } inline ll query(int x) { ll ans=0; while(x){ ans+=c[x]; x-=lowbit(x); } return ans; } int main() { scanf("%d",&n); for(int i=1;i<=n;i++){ scanf("%lld",&A[i].val); A[i].id=i; } sort(A+1,A+n+1,cmp); for(int i=1;i<=n;i++) a[A[i].id]=i; ll ans = 0; for(int i=1;i<=n;i++){ update(a[i],1); ans+=i-query(a[i]); } printf("%lld ",ans ); }