题目链接:http://poj.org/problem?id=3468
一般说来 树状数组是
【单点更新 区间查询】
而本题是“区间更新 区间查询”
所以要改成维护前缀和
推公式的过程见《挑战程序设计竞赛》第181~182面
代码中bit0是i的零次项
bit1是i的一次项
以后以此类推
在[l, r]加数的时候
写出公式
在l的地方 一次项及以上的直接写
然后在零次项那减去
在r的地方 一次项及以上的减掉之前加上的
然后再零次项那加上“公式化简之后的最终结果 减去 之前在零次项加的那一项”
【虽然很乱但我姑且是这么理解的】
以下是树状数组的写法:
#include <iostream> #include <cstdio> #include <cstdlib> #include <cstring> #include <set> #include <map> #include <stack> #include <algorithm> #include <cmath> #include <queue> #include <vector> using namespace std; typedef long long ll; typedef unsigned long long ull; const ull B = 9973; const int maxn = 110000; const ll M = 23333; ll bit0[maxn]; ll bit1[maxn]; int n, q; ll sum(ll *b, int i) { ll s = 0; while(i > 0) { s += b[i]; i -= i & -i; } return s; } void add(ll *b, int i, int v) { while(i <= n) { b[i] += v; i += i & -i; } } int main() { #ifdef LOCAL freopen("input.txt", "r", stdin); //freopen("output.txt", "w", stdout); #endif while(scanf("%d%d", &n, &q) == 2) { for(int i = 1; i <= n; i++) { int t; scanf("%d", &t); add(bit0, i, t); } for(int i = 0; i < q; i++) { char mode; scanf("%C", &mode); scanf("%C", &mode); if(mode == 'C') { int l, r, x; scanf("%d%d%d", &l, &r, &x); add(bit0, l, -x*(l-1)); add(bit1, l, x); add(bit0, r+1, x*r); add(bit1, r+1, -x); } else { int l, r; scanf("%d%d", &l, &r); ll res = 0; res += sum(bit0, r) + sum(bit1, r)*r; res -= sum(bit0, l-1) + sum(bit1, l-1)*(l-1); printf("%lld ", res); } } } return 0; }
然后 再有一个线段树的写法:
#include <iostream> #include <cstdio> #include <cstdlib> #include <cstring> #include <set> #include <map> #include <stack> #include <algorithm> #include <cmath> #include <queue> #include <vector> using namespace std; typedef long long ll; typedef unsigned long long ull; const int maxn = 410000; ll data[maxn], datb[maxn]; int n, q; void add(int a, int b, int x, int k, int l, int r) { if(a <= l && b >= r) { data[k] += x; } else if(l < b && r > a) { datb[k] += (min(b, r) - max(a, l))*x; add(a, b, x, k*2+1, l, (l+r)/2); add(a, b, x, k*2+2, (l+r)/2, r); } } ll sum(int a, int b, int k, int l, int r) { if(b <= l || a >= r) return 0; else if(a <= l && b >= r) { return data[k]*(r-l) + datb[k]; } else { ll res = (min(b, r) - max(a, l))*data[k]; res += sum(a, b, k*2+1, l, (l+r)/2); res += sum(a, b, k*2+2, (l+r)/2, r); return res; } } int main() { #ifdef LOCAL freopen("input.txt", "r", stdin); //freopen("output.txt", "w", stdout); #endif while(scanf("%d%d", &n, &q) == 2) { for(int i = 0; i < n; i++) { int t; scanf("%d", &t); add(i, i+1, t, 0, 0, n); } for(int i = 0; i < q; i++) { char mode; scanf("%c", &mode); scanf("%c", &mode); if(mode == 'C') { int a, b, x; scanf("%d%d%d", &a, &b, &x); add(a-1, b, x, 0, 0, n); } else { int a, b; scanf("%d%d", &a, &b); printf("%lld ", sum(a-1, b, 0, 0, n)); } } } return 0; }