zoukankan      html  css  js  c++  java
  • 一维二维树状数组写法总结

    (有任何问题欢迎留言或私聊 && 欢迎交流讨论哦)

    @




    一维习题:hdu1541 bzoj3211(hdu4027)
    二维习题:hdu2642 hdu1892 hdu5517

    一维树状数组:

    // 1-D Fenwick tree (binary indexed tree) over 1-based positions: point add and
    // prefix-sum query, each O(log N). MXN (the storage capacity) is defined elsewhere;
    // callers pass the logical size N to add_bit and must keep N < MXN.
    struct FenwickTree {
        int BIT[MXN];
        // Lowest set bit of x: BIT[x] aggregates the lowbit(x) positions ending at x.
        int lowbit(int x) {return x&(-x);}
        // Adds val at position x (1 <= x <= N). x <= 0 is ignored: lowbit(0) == 0, so the
        // loop would never advance, and a negative x would index BIT out of bounds.
        void add_bit(int x, int val, int N) {
            if (x <= 0) return;
            for (; x <= N; x += lowbit(x)) BIT[x] += val;
        }
        // Sum of positions 1..x; returns 0 for x <= 0 instead of walking negative indices.
        int query_bit(int x) {
            int ans = 0;
            for (; x > 0; x -= lowbit(x)) ans += BIT[x];
            return ans;
        }
    }bit;
    

    改点求段:

    // Point update / range query on a 1-D Fenwick tree.
    // Uses globals defined elsewhere: n (size), ar[] (the tree storage), lowbit().
    // Positions are 1-based; add() must be called with x >= 1.

    // Adds v to element x by bumping every tree node whose range contains x.
    void add(int x,int v){
      for (int pos = x; pos <= n; pos += lowbit(pos)) {
        ar[pos] += v;
      }
    }
    // Prefix sum of elements 1..x.
    int query(int x){
      int total = 0;
      for (int pos = x; pos > 0; pos -= lowbit(pos)) {
        total += ar[pos];
      }
      return total;
    }
    // Sum of elements l..r, inclusive.
    int range(int l, int r){
      const int upto_r = query(r);
      const int before_l = query(l - 1);
      return upto_r - before_l;
    }
    

    改段求点:

    // Range update / point query: the tree stores the difference array
    // delta[i] = ar[i] - ar[i-1], so the prefix sum of delta over 1..x equals ar[x].
    // To add v on [l, r], call add(l, v) and add(r + 1, -v).
    // Uses globals defined elsewhere: n, ar[], delta[], lowbit(); ar[0] is expected to be 0.

    // Adds v to delta[x] and to every tree node covering x (x >= 1).
    void add(int x,int v){
      for (int pos = x; pos <= n; pos += lowbit(pos)) delta[pos] += v;
    }
    // Prefix sum of the difference array over 1..x.
    int query(int x){
      int acc = 0;
      for (int pos = x; pos > 0; pos -= lowbit(pos)) acc += delta[pos];
      return acc;
    }
    // Reads ar[1..n] from stdin and loads their successive differences into the tree.
    void init(){
      for (int i = 1; i <= n; ++i) {
        scanf("%d", &ar[i]);
        const int diff = ar[i] - ar[i-1];
        add(i, diff);
      }
    }
    // Current value of element x.
    int get_pos(int x){
      return query(x);
    }
    

    改段求段:

    here

    // Range update / range query with two Fenwick trees over the difference array of
    // the updates (delta[] is that difference array, deltai[k] holds k*delta[k]):
    //   sum[i] = pre[i] + (i+1)*sigma(delta[1..i]) - sigma(k*delta[k], k = 1..i)
    // pre[] holds the prefix sums of the initial values ar[].
    // Uses globals defined elsewhere: n, ar[], pre[], delta[], deltai[] (passed as LL*), lowbit().

    // Adds v to tree a at position x (x >= 1).
    void add(LL *a, int x, LL v){
      for (int pos = x; pos <= n; pos += lowbit(pos)) a[pos] += v;
    }
    // Prefix sum of tree a over 1..x.
    LL query(LL *a, int x){
      LL acc = 0;
      for (int pos = x; pos > 0; pos -= lowbit(pos)) acc += a[pos];
      return acc;
    }
    // Reads ar[1..n] from stdin and builds the static prefix sums pre[0..n].
    void init(){
      pre[0] = 0;
      for (int i = 1; i <= n; ++i) {
        scanf("%lld", &ar[i]);
        pre[i] = pre[i-1] + ar[i];
      }
    }
    // Adds x to every element of [l, r].
    void update(int l, int r, LL x){
      const int past = r + 1;
      add(delta, l, x);
      add(delta, past, -x);
      add(deltai, l, l*x);
      add(deltai, past, -x*past);
    }
    // Sum of elements l..r: prefix(r) - prefix(l-1), where
    // prefix(k) = pre[k] + (k+1)*sum(delta, 1..k) - sum(deltai, 1..k).
    LL range(int l, int r){
      const LL upto_r   = pre[r] + (r+1)*query(delta, r) - query(deltai, r);
      const LL before_l = pre[l-1] + l*query(delta, l-1) - query(deltai, l-1);
      return upto_r - before_l;
    }
    

    二维树状数组:

    改点求段:

    void add(int x, int y, int z){
      int tmp = y;
      while(x<=n){
        y = tmp;
        while(y<=n){
          cw[x][y] += z, y += lowbit(y);
        }
        x += lowbit(x);
      }
    }
    int query(int x, int y){
      int res = 0, tmp = y;
      while(x){
        y = tmp;
        while(y){
          res += cw[x][y], y -= lowbit(y);
        }
        x -= lowbit(x);
      }
      return res;
    }
    

    改段求点:

    // delta[][] is the 2-D difference array: delta[i][j] = a[i][j] - (a[i-1][j] + a[i][j-1] - a[i-1][j-1]),
    // so the 2-D prefix sum of delta up to (x, y) equals a[x][y].
    // Uses globals defined elsewhere: n, ar[][], delta[][], lowbit(); row 0 and column 0
    // of ar are expected to be 0.

    // Adds z to delta at (x, y) and every tree node covering it.
    void add(int x, int y, int z){
      for (int r = x; r <= n; r += lowbit(r))
        for (int c = y; c <= n; c += lowbit(c))
          delta[r][c] += z;
    }
    // Adds z to every cell of the rectangle (xa, ya)-(xb, yb) with four corner updates.
    void update(int xa,int ya,int xb,int yb,int z){
      add(xa, ya, z);
      add(xa, yb + 1, -z);
      add(xb + 1, ya, -z);
      add(xb + 1, yb + 1, z);
    }
    // Current value of cell (x, y): the 2-D prefix sum of delta.
    int query(int x, int y){
      int value = 0;
      for (int r = x; r != 0; r -= lowbit(r))
        for (int c = y; c != 0; c -= lowbit(c))
          value += delta[r][c];
      return value;
    }
    // Loads the initial grid ar[1..n][1..n] into the tree as 2-D differences.
    void init(){
      for (int i = 1; i <= n; ++i)
        for (int j = 1; j <= n; ++j) {
          const int d = ar[i][j] - ar[i-1][j] - ar[i][j-1] + ar[i-1][j-1];
          add(i, j, d);
        }
    }
    

    改段求段:

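    $$\mathrm{sum}[x][y] = (x+1)(y+1)\sum_{i=1}^{x}\sum_{j=1}^{y} d[i][j] - (y+1)\sum_{i=1}^{x}\sum_{j=1}^{y} i\cdot d[i][j] - (x+1)\sum_{i=1}^{x}\sum_{j=1}^{y} j\cdot d[i][j] + \sum_{i=1}^{x}\sum_{j=1}^{y} i\cdot j\cdot d[i][j]$$(其中 $d$ 为二维差分数组)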

    // 2-D Fenwick tree supporting "add c to a rectangle" and "sum of a rectangle".
    // With d[][] the 2-D difference array of the grid, the prefix sum up to (x, y) is
    //   sum(x, y) = (x+1)(y+1)*S(d) - (y+1)*S(i*d) - (x+1)*S(j*d) + S(i*j*d),
    // where S(.) sums over 1 <= i <= x, 1 <= j <= y. The trees da, di, dj, dij hold
    // d, i*d, j*d and i*j*d respectively.

    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    #include<assert.h>
    #include<bitset>
    #define lson rt<<1
    #define rson rt<<1|1
    // Fully parenthesised so the macro composes safely inside larger expressions.
    #define lowbit(x) ((x)&(-(x)))
    #define all(x) (x).begin(),(x).end()
    using namespace std;
    typedef long long LL;
    const int INF = 0x3f3f3f3f;
    const int N = (int)1e3 +107;
    int ar[N][N];
    // 64-bit on purpose: di/dj/dij accumulate values scaled by coordinates (up to ~1e3 each),
    // and query() multiplies by (x+1)(y+1) (~1e6). That overflows a 32-bit int long before
    // the true rectangle sum does.
    LL da[N][N], di[N][N], dj[N][N], dij[N][N];
    int n, m, q;

    // Adds z at difference-array position (x, y) in all four trees.
    void add(int x, int y, int z){
      const LL v = z;
      for(int i=x;i<=n;i+=lowbit(i)){
        for(int j=y;j<=n;j+=lowbit(j)){
          da[i][j] += v;
          di[i][j] += v*x;
          dj[i][j] += v*y;
          dij[i][j] += v*x*y;
        }
      }
    }
    // Adds z to every cell of the rectangle [xa..xb] x [ya..yb] with four corner updates.
    void update(int xa,int ya,int xb,int yb,int z){
      add(xa,ya,z);add(xa,yb+1,-z);add(xb+1,ya,-z);add(xb+1,yb+1,z);
    }
    // Sum of cells (1..x, 1..y), evaluated with the formula at the top of this block.
    LL query(int x, int y){
      LL res = 0;
      const LL cx = x + 1, cy = y + 1;
      for(int i = x; i>0; i -= lowbit(i)){
        for(int j = y; j>0; j -= lowbit(j)){
          res += cx*cy*da[i][j] - cy*di[i][j] - cx*dj[i][j] + dij[i][j];
        }
      }
      return res;
    }
    // Sum of the rectangle [xa..xb] x [ya..yb] by inclusion-exclusion.
    LL ask(int xa,int ya,int xb,int yb){
      return query(xb,yb)-query(xb,ya-1)-query(xa-1,yb)+query(xa-1,ya-1);
    }
    // Loads the initial grid ar[1..n][1..n] into the trees as 2-D differences.
    // Row 0 and column 0 of ar are expected to be 0.
    void init(){
      for(int i = 1; i <= n; ++i){
        for(int j = 1; j <= n; ++j){
          int tmp = ar[i][j]-ar[i-1][j]-ar[i][j-1]+ar[i-1][j-1];
          add(i,j,tmp);
        }
      }
    }
    int main(){
      while(~scanf("%d", &n)){
        memset(ar,0,sizeof(ar));
        for(int i=1;i<=n;++i){
          for(int j=1;j<=n;++j){
            scanf("%d",&ar[i][j]);
          }
        }
        init();
        scanf("%d",&q);
        while(q--){
          int op,xa,xb,ya,yb,c;
          scanf("%d%d%d%d%d",&op,&xa,&ya,&xb,&yb);
          if(op==1){
            scanf("%d",&c);
            update(xa,ya,xb,yb,c);
          }else{
            printf("%d
    ", ask(xa,ya,xb,yb));
          }
          if(q<=0)break;
        }
      }
      return 0;
    }
    

    习题答案:

    HDU1892

    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    #include<assert.h>
    #include<bitset>
    #define lson rt<<1
    #define rson rt<<1|1
    // Fully parenthesised so the macro composes safely inside larger expressions.
    #define lowbit(x) ((x)&(-(x)))
    #define all(x) (x).begin(),(x).end()
    using namespace std;
    typedef long long LL;
    const int INF = 0x3f3f3f3f;
    const int N = (int)1e3 +107;
    // Point update + rectangle sum needs only one tree. The four extra N x N arrays the
    // range-update template declared here were never used and cost ~20 MB of static memory.
    int da[N][N];
    int n, m, q;

    // Adds c at cell (x, y) (1-based) of the n x n 2-D Fenwick tree.
    void add(int x,int y,int c){
      for(int i=x;i<=n;i+=lowbit(i)){
        for(int j=y;j<=n;j+=lowbit(j)){
          da[i][j]+=c;
        }
      }
    }
    // Sum over cells (1..x, 1..y).
    int query(int x,int y){
      int sum=0;
      for(int i=x;i;i-=lowbit(i)){
        for(int j=y;j;j-=lowbit(j)){
          sum+=da[i][j];
        }
      }
      return sum;
    }
    // Sum of the rectangle [xa..xb] x [ya..yb]; expects xa <= xb and ya <= yb.
    int ask(int xa,int ya,int xb,int yb){
      return query(xb,yb)-query(xa-1,yb)-query(xb,ya-1)+query(xa-1,ya-1);
    }
    // Driver for HDU1892. Every cell of the grid starts with one item; input coordinates
    // are 0-based and shifted by +1 for the 1-based tree. Operations:
    //   S x1 y1 x2 y2    print the total inside the rectangle (corners in any order)
    //   A x y c          add c at (x, y)
    //   D x y c          remove min(c, current) from (x, y)
    //   other x1 y1 x2 y2 c  move min(c, current at (x1, y1)) from (x1, y1) to (x2, y2)
    int main(){
      int tim;
      int tc=0;
      scanf("%d",&tim);
      while(tim--){
        n=1002;
        // All cells start at 1. Node (i, j) of a 2-D Fenwick tree covers
        // lowbit(i) * lowbit(j) cells, so the all-ones tree is built directly in O(n^2)
        // instead of n^2 point updates (O(n^2 log^2 n) per test case). Every index 1..n
        // is overwritten, so no memset is needed.
        for(int i=1;i<=n;++i){
          for(int j=1;j<=n;++j){
            da[i][j]=(i&-i)*(j&-j);
          }
        }
        printf("Case %d:\n", ++tc);
        scanf("%d",&q);
        while(q--){
          char op[8];
          int xa=0,xb=0,ya=0,yb=0,c;
          scanf("%7s",op);  // bounded read: op is a single letter
          if(op[0]=='S'){
            scanf("%d%d%d%d",&xa,&ya,&xb,&yb);
            xa++;ya++;xb++;yb++;
            if(xa>xb)swap(xa,xb);
            if(ya>yb)swap(ya,yb);
            printf("%d\n", ask(xa,ya,xb,yb));
          }else if(op[0]=='A'){
            scanf("%d%d%d",&xa,&ya,&c);
            xa++;ya++;
            add(xa,ya,c);
          }else if(op[0]=='D'){
            scanf("%d%d%d",&xa,&ya,&c);
            xa++;ya++;
            c=min(c,ask(xa,ya,xa,ya));
            add(xa,ya,-c);
          }else{
            scanf("%d%d%d%d%d",&xa,&ya,&xb,&yb,&c);
            xa++;ya++;xb++;yb++;
            c=min(c,ask(xa,ya,xa,ya));
            add(xa,ya,-c);
            add(xb,yb,c);
          }
        }
      }
      return 0;
    }
    

    HDU2642:

    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    #include<assert.h>
    #include<bitset>
    #define lson rt<<1
    #define rson rt<<1|1
    // Fully parenthesised so the macro composes safely inside larger expressions.
    #define lowbit(x) ((x)&(-(x)))
    #define all(x) (x).begin(),(x).end()
    using namespace std;
    typedef long long LL;
    const int INF = 0x3f3f3f3f;
    const int N = (int)1e3 +107;
    // ar[][] holds the per-cell 0/1 state; da[][] is the Fenwick tree over it.
    // The three extra N x N arrays the range-update template declared were never used
    // (~15 MB of static memory) and have been dropped.
    int ar[N][N], da[N][N];
    int n, m, q;

    // Adds z at cell (x, y) (1-based) of the n x n 2-D Fenwick tree.
    void add(int x, int y, int z){
      for(int i=x;i<=n;i+=lowbit(i)){
        for(int j=y;j<=n;j+=lowbit(j)){
          da[i][j] += z;
        }
      }
    }
    // Sum over cells (1..x, 1..y).
    int query(int x, int y){
      int res = 0;
      for(int i = x; i>0; i -= lowbit(i)){
        for(int j = y; j>0; j -= lowbit(j)){
          res += da[i][j];
        }
      }
      return res;
    }
    // Sum of the rectangle [xa..xb] x [ya..yb]; expects xa <= xb and ya <= yb.
    int ask(int xa,int ya,int xb,int yb){
      return query(xb,yb)-query(xb,ya-1)-query(xa-1,yb)+query(xa-1,ya-1);
    }
    int main(){
      while(~scanf("%d", &q)){
        n=1001;
        memset(da,0,sizeof(da));
        memset(ar,0,sizeof(ar));
        while(q--){
          char op[2];
          int xa,xb,ya,yb,c;
          scanf("%s",op);
          if(op[0]=='B'){
            scanf("%d%d",&xa,&ya);
            xa++;ya++;
            if(ar[xa][ya]==0)add(xa,ya,1);
            ar[xa][ya]=1;
          }else if(op[0]=='D'){
            scanf("%d%d",&xa,&ya);
            xa++;ya++;
            if(ar[xa][ya]==1)add(xa,ya,-1);
            ar[xa][ya]=0;
          }else{
            scanf("%d%d",&xa,&xb);
            scanf("%d%d",&ya,&yb);
            xa++;ya++;
            xb++;yb++;
            if(xa>xb)swap(xa,xb);
            if(ya>yb)swap(ya,yb);
            printf("%d
    ", ask(xa,ya,xb,yb));
          }
        }
      }
      return 0;
    }
    

    hdu1541

    #include<iostream>
    #include<algorithm>
    #include<cstring>
    #include<cstdio>
    using namespace std;
    const int N=1e5+10;
    // a[]: Fenwick tree over (shifted) x-coordinates, valid indices 1..N-1.
    // sum[k]: the driver's count of points whose level is k.
    int a[N],sum[N];
    int lowbit(int x){return x&(-x);}
    // Number of points added so far with (shifted) x-coordinate in 1..n.
    int Sum(int n){
        int sum=0;
        while(n>0){
            sum+=a[n];
            n-=lowbit(n);
        }
        return sum;
    }
    // Records one point at (shifted) x-coordinate x, 1 <= x < N.
    // The loop bound is x < N, not x <= N: a[] has N elements, so a[N] is out of range.
    void add(int x){
        while(x<N){
            ++a[x];
            x+=lowbit(x);
        }
    }
    // Driver for hdu1541: for each test, read n points (x, y). A point's level is the number
    // of earlier points with x' <= x (assumes the input is ordered by y, as in hdu1541).
    // Then print how many points have each level 0..n-1.
    int main() {
        int x,y,n;
        while(~scanf("%d",&n)){
            memset(a,0,sizeof(a));
            memset(sum,0,sizeof(sum));
            for(int i=0;i<n;++i){
                scanf("%d %d",&x,&y);
                // Example: x = 1 5 7 3 5  ->  shifted 2 6 8 4 6  ->  levels 0 1 2 1 3
                // Shift by one because x may be 0 while Fenwick indices start at 1.
                x++;
                sum[Sum(x)]++;
                add(x);
            }
            for(int i=0;i<n;++i){
                printf("%d\n",sum[i]);
            }
        }
        return 0;
    }
    

    具体请访问这个博客:here

    树状数组求区间最值

    $bit[x]$ 是区间 $[x-\mathrm{lowbit}(x)+1,\ x]$ 的最值
    能转移到 $x$ 的状态是 $x-2^0,\ x-2^1,\ \dots,\ x-2^k$(其中 $2^k < \mathrm{lowbit}(x)$)
    若 $y - \mathrm{lowbit}(y) \ge x$,则 $query(x,y) = \max\big(bit[y],\ query(x,\ y-\mathrm{lowbit}(y))\big)$
    若 $y - \mathrm{lowbit}(y) < x$,则 $query(x,y) = \max\big(ar[y],\ query(x,\ y-1)\big)$

    #include<iostream>
    #include<cstdio>
    #include<assert.h>
    #include<ctime>
    #include<algorithm>
    #include<cstring>
    //#include<bits/stdc++.h>
    // NOTE: there is deliberately no lowbit macro here. A function-like macro named lowbit
    // would expand inside the lowbit() function definition below and break compilation.
    #define all(x) x.begin(),x.end()
    #define iis std::ios::sync_with_stdio(false)
    #define mme(a,b) memset((a),(b),sizeof((a)))
    using namespace std;
    typedef long long LL;
    const int MXN = 2e5+7;
    const int INF = 0x3f3f3f3f;
    const int MOD = 1e9 + 7;
    int n, m;
    // ar[i]: current values (1-based); bit[x]: maximum of ar over [x - lowbit(x) + 1, x].
    int ar[MXN], bit[MXN];

    // Lowest set bit of x.
    int lowbit(int x){return x & (-x);}
    // a = max(a, b).
    inline void mymax(int &a,int b){a = a > b? a: b;}
    // Rebuilds bit[] for every node covering position x after ar[x] changed. Each node is
    // recomputed from ar[x] and its child nodes x-1, x-2, x-4, ... (offsets < lowbit(x)),
    // so a call costs O(log^2 n).
    void update(int x){
      while(x <= n){
        bit[x] = ar[x];
        int tmp = lowbit(x);
        for(int i = 1; i < tmp; i <<= 1){
          mymax(bit[x], bit[x-i]);
        }
        x += lowbit(x);
      }
    }
    // Maximum of ar[x..y] (1-based, inclusive). Returns 0 for an empty range (y < x).
    // Takes whole nodes bit[y] when they fit inside [x, y], otherwise single ar[y] values.
    int query(int x,int y){
      if(y < x) return 0;
      // Seed with a real element (not 0) so ranges of all-negative values work.
      int ans = ar[y];
      while(y >= x){
        mymax(ans, ar[y]);
        --y;
        for( ; y - lowbit(y) >= x; y -= lowbit(y)){
          mymax(ans, bit[y]);
        }
      }
      return ans;
    }
    // Driver (HDU1754-style): read n values, then m operations:
    //   "U a b" sets ar[a] = b; any other letter prints max(ar[a..b]).
    int main(){
      char op;
      int a, b;
      // "== 2" (not "!= EOF") also stops on malformed input instead of looping forever.
      while(scanf("%d%d", &n, &m) == 2){
        // `register` was dropped: the specifier is ill-formed since C++17.
        for(int i = 0; i <= n; ++i)bit[i] = 0;
        for(int i = 1; i <= n; ++i){
          scanf("%d", &ar[i]);
          update(i);
        }
        while(m--){
          // " %c" skips any whitespace (newlines, '\r' from CRLF input, extra spaces)
          // before the operation letter, rather than assuming exactly one separator.
          scanf(" %c", &op);
          scanf("%d%d", &a, &b);
          if(op == 'U') {
            ar[a] = b;
            update(a);
          }else {
            a = query(a, b);
            printf("%d\n", a);
          }
        }
      }
      return 0;
    }
    
    
  • 相关阅读:
    BestCoder Round #29 1003 (hdu 5172) GTY's gay friends [线段树 判不同 预处理 好题]
    POJ 1182 食物链 [并查集 带权并查集 开拓思路]
    Codeforces Round #288 (Div. 2) E. Arthur and Brackets [dp 贪心]
    Codeforces Round #287 (Div. 2) E. Breaking Good [Dijkstra 最短路 优先队列]
    Codeforces Round #287 (Div. 2) D. The Maths Lecture [数位dp]
    NOJ1203 最多约数问题 [搜索 数论]
    poj1426
    POJ 1502 MPI Maelstrom [最短路 Dijkstra]
    POJ 2785 4 Values whose Sum is 0 [二分]
    浅析group by,having count()
  • 原文地址:https://www.cnblogs.com/Cwolf9/p/9513252.html
Copyright © 2011-2022 走看看