以前一直以为树状数组只能处理单点查询和区间修改、单点修改和区间查询的问题。最近看到两个代码,发现自己太out了。它还能处理区间查询和区间修改(加减一个数)的问题。我们先来看一维的情况。
对$[s,t]$区间增加$x$,利用差分的思想,只需要让$delta[s]+x,delta[t+1]-x$即可。其中$delta$表示差分数组。
查询区间$[s,t]$的和,可以通过区间$[1,t]$-$[1,s-1]$得来。而$sum(1,t)=sum_{i=1}^t delta[i]*(t-i+1)$,$sum(1,s-1)=sum_{i=1}^{s-1}delta[i]*(s-i)$
于是$sum[s,t]=sum[1,t]-sum[1,s-1]=sum_{i=1}^{t}delta[i]*(t+1)-sum_{i=1}^{s-1}delta[i]*s-sum_{i=s}^t(delta[i]*i)$
其中$num'$表示修改过后的数组,$num$表示原数组,$delta$表示差分数组。
我们用两个树状数组,一个维护$delta[i]$,一个维护$delta[i]*i$,那么上面的区间修改和区间查询均可以在$O(logN)$的时间内完成了。
再看两维的情况:
对区间$[x1,y1,x2,y2]$增加x,同样利用差分的思想,只需要让$delta[x1][y1]+x,delta[x1,y2+1]-x,delta[x2,y1+1]-x,delta[x2+1][y2+1]+x$即可。同样,$delta[x][y]$表示从点$[x,y]$到点$[n,m]$的区域的增量。
定义$[x1,y1,x2,y2]$表示左下角顶点为$[x1,y1]$,右上角顶点为$[x2,y2]$的区域,该区域的数之和用$sum(x1,y1,x2,y2)$表示。可得$sum(x1,y1,x2,y2)=sum(1,1,x2,y2)-sum(1,1,x2,y1-1)-sum(1,1,x1-1,y2)+sum(1,1,x1-1,y1-1)$.
如何求$sum(1,1,x2,y2)$呢?
$sum(1,1,x2,y2)=sumlimits_{i=1}^{x2}sumlimits_{j=1}^{y2}(sumlimits_{i'=1}^{i} sumlimits_{j'=1}^j delta[i'][j'])=sumlimits_{i=1}^{x2} sumlimits_{j=1}^{y2} delta[i][j]*(x2-i+1)*(y2-j+1)$
对$num$求和可以使用前缀和预处理,而对delta数组求和
$sumlimits_{i=1}^{x2} sumlimits_{j=1}^{y2} delta[i][j]*(x2-i+1)*(y2-j+1)=sumlimits_{i=1}^{x2} sumlimits_{j=1}^{y2} delta[i][j]*((x2+1)(y2+1)-(x2+1)*j-(y2+1)*i+i*j)\=(x2+1)*(y2+1)*sumlimits_{i=1}^{x2}sumlimits_{j=1}^{y2}delta[i][j]\-(x2+1)*sumlimits_{i=1}^{x2} sumlimits_{j=1}^{y2}(j*delta[i][j])\-(y2+1)*sumlimits_{i=1}^{x2} sumlimits_{j=1}^{y2}(i*delta[i][j])\+sumlimits_{i=1}^{x2} sumlimits_{j=1}^{y2}delta[i][j]*i*j$
上式最终可以分为四个树状数组求和,四个树状数组分别维护$delta[i][j]$,$delta[i][j]*i$,$delta[i][j]*j$,$delta[i][j]*i*j$.
下面是用二维树状数组完成二维区间查询和二维区间修改的操作。
#include <iostream> #include<cstdio> #include<cstring> #include<cstdlib> using namespace std; #define MAXN 2055 int x1,y1,x2,y2,n,m,val; char opt[3]; int tree[5][MAXN][MAXN]; #define lowbit(x) (x&-(x)) void update(int w,int posx,int posy,int val) { for(int i=posx;i<=n;i+=lowbit(i)) { for(int j=posy;j<=m;j+=lowbit(j)) { tree[w][i][j]+=val; } } } int getsum(int w,int posx,int posy) { int sum=0; for(int i=posx;i;i-=lowbit(i)) { for(int j=posy;j;j-=lowbit(j)) { sum+=tree[w][i][j]; } } return sum; } void up(int a,int b,int val) { update(0,a,b,val); update(1,a,b,b*val); update(2,a,b,a*val); update(3,a,b,a*b*val); } int sum(int a,int b) { return (a+1)*(b+1)*getsum(0,a,b)-(b+1)*getsum(2,a,b)-(a+1)*getsum(1,a,b)+getsum(3,a,b); } int main() { scanf("%s%d%d",opt,&n,&m); while(scanf("%s",opt)!=-1) { if(opt[0]=='L') { scanf("%d%d%d%d%d",&x1,&y1,&x2,&y2,&val); up(x1,y1,val); up(x1,y2+1,-val); up(x2+1,y1,-val); up(x2+1,y2+1,val); } else { scanf("%d%d%d%d",&x1,&y1,&x2,&y2); printf("%d",sum(x2,y2)-sum(x2,y1-1)-sum(x1-1,y2)+sum(x1-1,y1-1)); } } printf(" "); for(int i=1;i<=n;i++) for(int j=1;j<=m;j++) { printf("%d ",tree[0][i][j]); if(j==m)printf(" "); else printf(" "); } printf(" "); return 0; }