一、前言
对线段树还挺熟悉的我之前却从没写过分块的题(?)。然后兴致一来就决定搜点练习,找到了来自hzwer的《数列分块入门1-9》,觉得挺不错的,于是决定做做。
二、概念 / 作用
概念:将数列等分为若干个不相交的区间,每一个区间称为一个块。
作用:优化算法,降低复杂度。具体如何降低,在下面的题目中会逐步提及。题目呈难度递增趋势。
三、题目 / 代码
1、分块入门1(传送门:https://loj.ac/problem/6277)
题面:给出一个长为 n 的数列,以及 n 个操作,操作涉及区间加法,单点查询。
挺多数据结构均能实现的经典题目,譬如线段树。这里我们用分块来做。将 n 个元素等分为若干块,比如{1, 4, 8, 2, 9, 6, 3, 7, 5},等分为3块,则第一块包含的数据为{1, 4, 8},第二、三块以此类推。我们给每一个块增加一个加法标记,对于每次的区间[l, r]加法操作,直接对块进行标记叠加。
l, r必然不一定是块的边界,也就意味着左右端点可能在块的中间,直接一个个暴力增加。设块的元素个数为m,标记块个数至多n / m个,暴力增加元素个数至多2m个,复杂度分析:O(n / m) + O(m),根据均值不等式,可证明m = √n时存在最低复杂度。故我们以下所有分块大小均默认为√n。
询问就很轻松了,直接返回元素的值加上所在区间的标记。
代码:
1 #include <bits/stdc++.h> 2 using namespace std; 3 4 #define MAXN 50005 5 6 int n, a[MAXN], x, b[MAXN], f[MAXN]; 7 8 void add(int l, int r, int w) { 9 for (int i = l; i <= min(b[l] * x, r); i++) a[i] += w; 10 if (b[l] != b[r]) 11 for (int i = (b[r] - 1) * x + 1; i <= r; i++) a[i] += w; 12 for (int i = b[l] + 1; i <= b[r] - 1; i++) f[i] += w; 13 } 14 15 int main() { 16 int o, l, r, w; 17 scanf("%d", &n), x = sqrt(n); 18 for (int i = 1; i <= n; i++) scanf("%d", &a[i]); 19 for (int i = 1; i <= n; i++) b[i] = (i - 1) / x + 1; 20 for (int i = 1; i <= n; i++) { 21 scanf("%d %d %d %d", &o, &l, &r, &w); 22 if (o) printf("%d ", a[r] + f[b[r]]); 23 else add(l, r, w); 24 } 25 return 0; 26 }
2、分块入门2(传送门:https://loj.ac/problem/6278)
题面:给出一个长为 n 的数列,以及 n 个操作,操作涉及区间加法,询问区间内小于某个值 x 的元素个数。
// 对于题目本身的分析就不多赘述了,因为hzwer已经分析的太好了,不在关公们面前耍刀了,所以需要具体的做题思路可转至hzwer的博客(见上)。
hzwer原代码使用了vector,正好去了解了下vector(https://www.cnblogs.com/jinkun113/p/10691919.html)。当然此题不用vector问题也不大,这里把两个版本都贴出来了。
不论是用普通数组还是vector,有一个需要注意的点是,对于每次的区间修改,左右端点所在的块会因修改而不再呈升序排列,所以需要重新维护。
代码 - 普通版:
1 #include <bits/stdc++.h> 2 using namespace std; 3 4 #define MAXN 50005 5 6 int n, x; 7 int a[MAXN], b[MAXN], f[MAXN], c[505][505]; 8 9 void reset(int o) { 10 memset(c[o], 0, sizeof(c[o])); 11 for (int i = (o - 1) * x + 1; i <= min(o * x, n); i++) 12 c[o][++c[o][0]] = a[i]; 13 sort(c[o] + 1, c[o] + c[o][0] + 1); 14 } 15 16 void add(int l, int r, int w) { 17 for (int i = l; i <= min(b[l] * x, r); i++) a[i] += w; 18 reset(b[l]); 19 if (b[l] != b[r]) { 20 for (int i = (b[r] - 1) * x + 1; i <= r; i++) a[i] += w; 21 reset(b[r]); 22 } 23 for (int i = b[l] + 1; i <= b[r] - 1; i++) f[i] += w; 24 } 25 26 int query(int l, int r, int w) { 27 int ans = 0; 28 for (int i = l; i <= min(b[l] * x, r); i++) 29 if (a[i] + f[b[l]] < w) ans++; 30 if (b[l] != b[r]) 31 for (int i = (b[r] - 1) * x + 1; i <= r; i++) 32 if (a[i] + f[b[r]] < w) ans++; 33 for (int i = b[l] + 1; i <= b[r] - 1; i++) { 34 int x = w - f[i]; 35 for (int j = 1; j <= c[i][0]; j++) 36 if (c[i][j] < x) ans++; 37 else break; 38 } 39 return ans; 40 } 41 42 int main() { 43 int o, l, r, w; 44 scanf("%d", &n), x = sqrt(n); 45 for (int i = 1; i <= n; i++) scanf("%d", &a[i]); 46 for (int i = 1; i <= n; i++) 47 b[i] = (i - 1) / x + 1, c[b[i]][++c[b[i]][0]] = a[i]; 48 for (int i = 1; i <= b[n]; i++) sort(c[i] + 1, c[i] + c[i][0] + 1); 49 for (int i = 1; i <= n; i++) { 50 scanf("%d %d %d %d", &o, &l, &r, &w); 51 if (!o) add(l, r, w); 52 else printf("%d ", query(l, r, w * w)); 53 } 54 return 0; 55 }
代码 - vector版:
1 #include <bits/stdc++.h> 2 using namespace std; 3 4 #define MAXN 50005 5 6 int n, x; 7 int a[MAXN], b[MAXN], f[MAXN]; 8 9 vector <int> v[505]; 10 11 void reset(int o) { 12 v[o].clear(); 13 for (int i = (o - 1) * x + 1; i <= min(o * x, n); i++) v[o].push_back(a[i]); 14 sort(v[o].begin(), v[o].end()); 15 } 16 17 void add(int l, int r, int w) { 18 for (int i = l; i <= min(b[l] * x, r); i++) a[i] += w; 19 reset(b[l]); 20 if (b[l] != b[r]) { 21 for (int i = (b[r] - 1) * x + 1; i <= r; i++) a[i] += w; 22 reset(b[r]); 23 } 24 for (int i = b[l] + 1; i <= b[r] - 1; i++) f[i] += w; 25 } 26 27 int query(int l, int r, int w) { 28 int ans = 0; 29 for (int i = l; i <= min(b[l] * x, r); i++) 30 if (a[i] + f[b[l]] < w) ans++; 31 if (b[l] != b[r]) 32 for (int i = (b[r] - 1) * x + 1; i <= r; i++) 33 if (a[i] + f[b[r]] < w) ans++; 34 for (int i = b[l] + 1; i <= b[r] - 1; i++) { 35 int x = w - f[i]; 36 ans += lower_bound(v[i].begin(), v[i].end(), x) - v[i].begin(); 37 } 38 return ans; 39 } 40 41 int main() { 42 int o, l, r, w; 43 scanf("%d", &n), x = sqrt(n); 44 for (int i = 1; i <= n; i++) scanf("%d", &a[i]); 45 for (int i = 1; i <= n; i++) 46 b[i] = (i - 1) / x + 1, v[b[i]].push_back(a[i]); 47 for (int i = 1; i <= b[n]; i++) sort(v[i].begin(), v[i].end()); 48 for (int i = 1; i <= n; i++) { 49 scanf("%d %d %d %d", &o, &l, &r, &w); 50 if (!o) add(l, r, w); 51 else printf("%d ", query(l, r, w * w)); 52 } 53 return 0; 54 }
To be continued...