问题:一个由数字构成的大矩阵,能进行两种操作
1) 对矩阵里的某个数加上一个整数(可正可负)
2) 查询某个子矩阵里所有数字的和,要求对每次查询,输出结果。
一维树状数组很容易扩展到二维,在二维情况下:数组A[][]的树状数组定义为:
C[x][y] = ∑ a[i][j], 其中,
x-lowbit(x) + 1 <= i <= x,
y-lowbit(y) + 1 <= j <= y.
例:举个例子来看看C[][]的组成。
设原始二维数组为:
A[][]={{a11,a12,a13,a14,a15,a16,a17,a18,a19},
{a21,a22,a23,a24,a25,a26,a27,a28,a29},
{a31,a32,a33,a34,a35,a36,a37,a38,a39},
{a41,a42,a43,a44,a45,a46,a47,a48,a49}};
那么它对应的二维树状数组C[][]呢?
记:
B[1]={a11,a11+a12,a13,a11+a12+a13+a14,a15,a15+a16,…} 这是第一行的一维树状数组
B[2]={a21,a21+a22,a23,a21+a22+a23+a24,a25,a25+a26,…} 这是第二行的一维树状数组
B[3]={a31,a31+a32,a33,a31+a32+a33+a34,a35,a35+a36,…} 这是第三行的一维树状数组
B[4]={a41,a41+a42,a43,a41+a42+a43+a44,a45,a45+a46,…} 这是第四行的一维树状数组
那么:
C[1][1]=a11,C[1][2]=a11+a12,C[1][3]=a13,C[1][4]=a11+a12+a13+a14,c[1][5]=a15,C[1][6]=a15+a16,…
这是A[][]第一行的一维树状数组
C[2][1]=a11+a21,C[2][2]=a11+a12+a21+a22,C[2][3]=a13+a23,C[2][4]=a11+a12+a13+a14+a21+a22+a23+a24,
C[2][5]=a15+a25,C[2][6]=a15+a16+a25+a26,…
这是A[][]数组第一行与第二行相加后的树状数组
C[3][1]=a31,C[3][2]=a31+a32,C[3][3]=a33,C[3][4]=a31+a32+a33+a34,C[3][5]=a35,C[3][6]=a35+a36,…
这是A[][]第三行的一维树状数组
C[4][1]=a11+a21+a31+a41,C[4][2]=a11+a12+a21+a22+a31+a32+a41+a42,C[4][3]=a13+a23+a33+a43,…
这是A[][]数组第一行+第二行+第三行+第四行后的树状数组
搞清楚了二维树状数组C[][]的规律了吗? 仔细研究一下,会发现:
(1)在二维情况下,如果修改了A[i][j]=delta,则对应的二维树状数组更新函数为:
1 void Modify(int i, int j, int delta) 2 { 3 4 A[i][j]+=delta; 5 6 for(int x = i; x< A.length; x += lowbit(x)) 7 for(int y = j; y <A[i].length; y += lowbit(y)) 8 { 9 C[x][y] += delta; 10 } 11 }
(2)在二维情况下,求子矩阵元素之和∑ a[i]j的函数为
1 int Sum(int i, int j) 2 { 3 int result = 0; 4 for(int x = i; x > 0; x -= lowbit(x)) 5 { 6 for(int y = j; y > 0; y -= lowbit(y)) 7 { 8 result += C[x][y]; 9 } 10 } 11 return result; 12 }
比如:
Sun(1,1)=C[1][1]; Sun(1,2)=C[1][2]; Sun(1,3)=C[1][3]+C[1][2];…
Sun(2,1)=C[2][1]; Sun(2,2)=C[2][2]; Sun(2,3)=C[2][3]+C[2][2];…
Sun(3,1)=C[3][1]+C[2][1]; Sun(3,2)=C[3][2]+C[2][2];
1 #include <cstdio> 2 #include <cstring> 3 typedef long long LL; 4 5 const int N = 1100; 6 7 int t, n; 8 LL bit[N][N]; 9 10 inline int lowbit(int x) { 11 return x & (-x); 12 } 13 14 LL Query(int x, int y) { 15 LL ans = 0; 16 for (int i = x; i > 0 ; i -= lowbit(i)) 17 for (int j = y; j > 0; j -= lowbit(j)) 18 ans += bit[i][j]; 19 return ans; 20 } 21 22 void Modify(int x, int y, int c) { 23 for (int i = x; i < N; i += lowbit(i)) 24 for (int j = y; j < N; j += lowbit(j)) 25 bit[i][j] += c; 26 }
区间修改 区间查询
类比之前一维数组的区间修改区间查询,下面这个式子表示的是点(x, y)的二维前缀和:
这个式子炒鸡复杂( O(n4)O(n4) 复杂度!),但利用树状数组,我们可以把它优化到 O(log2n)O(log2n)!
首先,类比一维数组,统计一下每个d[h][k]d[h][k]出现过多少次。d[1][1]d[1][1]出现了x∗yx∗y次,d[1][2]d[1][2]出现了x∗(y−1)x∗(y−1)次……d[h][k]d[h][k] 出现了 (x−h+1)∗(y−k+1)(x−h+1)∗(y−k+1) 次。
那么这个式子就可以写成:
把这个式子展开,就得到:
那么我们要开四个树状数组,分别维护:
d[i][j],d[i][j]∗i,d[i][j]∗j,d[i][j]∗i∗j
1 #include <cstdio> 2 #include <cmath> 3 #include <cstring> 4 #include <algorithm> 5 #include <iostream> 6 using namespace std; 7 typedef long long ll; 8 ll read(){ 9 char c; bool op = 0; 10 while((c = getchar()) < '0' || c > '9') 11 if(c == '-') op = 1; 12 ll res = c - '0'; 13 while((c = getchar()) >= '0' && c <= '9') 14 res = res * 10 + c - '0'; 15 return op ? -res : res; 16 } 17 const int N = 205; 18 ll n, m, Q; 19 ll t1[N][N], t2[N][N], t3[N][N], t4[N][N]; 20 void add(ll x, ll y, ll z){ 21 for(int X = x; X <= n; X += X & -X) 22 for(int Y = y; Y <= m; Y += Y & -Y){ 23 t1[X][Y] += z; 24 t2[X][Y] += z * x; 25 t3[X][Y] += z * y; 26 t4[X][Y] += z * x * y; 27 } 28 } 29 void range_add(ll xa, ll ya, ll xb, ll yb, ll z){ //(xa, ya) 到 (xb, yb) 的矩形 30 add(xa, ya, z); 31 add(xa, yb + 1, -z); 32 add(xb + 1, ya, -z); 33 add(xb + 1, yb + 1, z); 34 } 35 ll ask(ll x, ll y){ 36 ll res = 0; 37 for(int i = x; i; i -= i & -i) 38 for(int j = y; j; j -= j & -j) 39 res += (x + 1) * (y + 1) * t1[i][j] 40 - (y + 1) * t2[i][j] 41 - (x + 1) * t3[i][j] 42 + t4[i][j]; 43 return res; 44 } 45 ll range_ask(ll xa, ll ya, ll xb, ll yb){ 46 return ask(xb, yb) - ask(xb, ya - 1) - ask(xa - 1, yb) + ask(xa - 1, ya - 1); 47 } 48 int main(){ 49 n = read(), m = read(), Q = read(); 50 for(int i = 1; i <= n; i++){ 51 for(int j = 1; j <= m; j++){ 52 ll z = read(); 53 range_add(i, j, i, j, z); 54 } 55 } 56 while(Q--){ 57 ll ya = read(), xa = read(), yb = read(), xb = read(), z = read(), a = read(); 58 if(range_ask(xa, ya, xb, yb) < z * (xb - xa + 1) * (yb - ya + 1)) 59 range_add(xa, ya, xb, yb, a); 60 } 61 for(int i = 1; i <= n; i++){ 62 for(int j = 1; j <= m; j++) 63 printf("%lld ", range_ask(i, j, i, j)); 64 putchar(' '); 65 } 66 return 0; 67 }