在开始学习莫队之前,照例先甩一道例题:BZOJ 1878 HH的项链。
题意:求区间内数的个数,相同的数只算一次。
在我关于这道题的上一篇题解中,我使用了主席树来在线做这道题;在洛谷的一道类似题中,我使用了分块;而如果不要求在线,这道题还有一种极其好写的方法——莫队。
什么是莫队?
莫队不是一种叫做莫的队列(我第一次听到这个名字时竟然是这么理解的 -_-|||),它是以发明人前国家队队长莫涛——“莫队”的名字命名的。
它是一种传说中“能解决一切区间问题”的算法。首先,我们先来学习最简单的莫队——可离线、无修改的莫队。
可离线、无修改的莫队
莫队算法的精髓就是通过合理地对询问排序,然后以较优的顺序暴力回答每个询问。处理完一个询问后,可以使用它的信息得到下一个询问区间的答案。
考虑这个问题:对于上面这道题,已知区间 ([1, 5]) 的答案,求 ([2, 6]) 的答案,如何暴力求?
当然,可以将区间 ([2, 6]) 从头到尾扫一遍,直接求出答案,也可以在区间 ([1, 5]) 的基础上,去掉位置(1)(即将左端点右移一位),加上位置(6)(即将右端点右移一位),得到区间 ([2, 6]) 的答案。
在莫队算法中,我们可以使用第二种求答案的方法。至于为什么要用这个貌似与前面那种方法复杂度毫无区别的方法?当然是因为经过“合理的排序”后,这种方法可以被优化啦。
接下来我们还需要考虑一个问题:如何“合理地对询问排序”?
莫队提供了这样一个排序方案:将原序列以(sqrt n)为一块进行分块,排序第一关键字是询问的左端点所在块的编号,第二关键字是询问的右端点本身的位置,都是升序。然后我们用上面提到的“移动当前区间左右端点”的方法,按顺序求每个询问区间的答案,移动每一个询问区间左右端点可以求出下一个区间的答案。
具体的核心部分代码:
sort(q + 1, q + m + 1); //将询问排序
int ql = 1, qr = 0; //初始区间是一个空区间
for(int i = 1; i <= m; i++){
while(pl < q[i].l) del(a[pl++]); //
while(pl > q[i].l) add(a[--pl]);
while(pr < q[i].r) add(a[++pr]);
while(pr > q[i].r) del(a[pr--]);
ans[q[i].id] = sum;
}
这样就可以求出答案了!
——可是,这样做的复杂度是什么?
显然,每次移动左端点(或右端点)的复杂度都是(O(1))。那么只需要知道左右端点分别移动了多少次,就可以知道复杂度了!
对于右端点:当当前询问的左端点在同一块时,右端点都是有序的,那么右端点最多会从1一直移动到n;两个询问左端点在不同块(即“跨块”)时,最多从n一下子移回1。两种都是(O(n))的,总共有(sqrt n)块,所以复杂度是(O(n sqrt n))。
对于左端点:当当前询问的左端点在同一块时,注意左端点不是有序的,那么一次最多从块的一端移到另一端,复杂度(O(sqrt n)),总共有n个询问的话,复杂度是(O(sqrt n))。跨块时也类似,移动距离也是(O(sqrt n))。
综上,莫队的复杂度是(O(sqrt n))!
莫队的一大优点是:代码思路极其简单,尤其是序列上无需修改的莫队,核心代码只有五行:
while(pl < q[i].l) del(a[pl++]);
while(pl > q[i].l) add(a[--pl]);
while(pr < q[i].r) add(a[++pr]);
while(pr > q[i].r) del(a[pr--]);
ans[q[i].id] = sum;
其中 (pl)、(pr)一开始是上一次询问的左右端点,结束后变成了这一次询问的左右端点。(sum)维护当前区间([pl, pr])的答案。
BZOJ 1878 HH的项链我的AC代码:
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <set>
using namespace std;
typedef long long ll;
#define space putchar(' ')
#define enter putchar('
')
template <class T>
void read(T &x){
char c;
bool op = 0;
while(c = getchar(), c < '0' || c > '9')
if(c == '-') op = 1;
x = c - '0';
while(c = getchar(), c >= '0' && c <= '9')
x = x * 10 + c - '0';
if(op) x = -x;
}
template <class T>
void write(T x){
if(x < 0) x = -x, putchar('-');
if(x >= 10) write(x / 10);
putchar('0' + x % 10);
}
const int N = 50005, M = 200005, B = 233;
int n, m, a[N], sum, ans[M], cnt[1000005];
#define bel(x) ((x - 1) / B + 1)
struct query {
int id, l, r;
bool operator < (const query &b) const{
return bel(l) == bel(b.l) ? r < b.r : l < b.l;
}
} q[M];
void add(int x){
if(!cnt[x]) sum++;
cnt[x]++;
}
void del(int x){
cnt[x]--;
if(!cnt[x]) sum--;
}
int main(){
read(n);
for(int i = 1; i <= n; i++) read(a[i]);
read(m);
for(int i = 1; i <= m; i++)
q[i].id = i, read(q[i].l), read(q[i].r);
sort(q + 1, q + m + 1);
int pl = 1, pr = 0;
for(int i = 1; i <= m; i++){
while(pl < q[i].l) del(a[pl++]);
while(pl > q[i].l) add(a[--pl]);
while(pr < q[i].r) add(a[++pr]);
while(pr > q[i].r) del(a[pr--]);
ans[q[i].id] = sum;
}
for(int i = 1; i <= m; i++)
write(ans[i]), enter;
return 0;
}
可以单点修改的莫队
写完了上面这道题,可以发现:普通的莫队算法没有支持修改。那么如何改造该算法使它支持修改呢?
莫队算法被称为“优雅的暴力”,那么我们改造莫队算法的思路也只有一个:改造询问排序的方式,然后继续暴力。
这一次,排序的方式是:以(n^{frac{2}{3}})为一块,一共将序列分为(n^{frac{1}{3}})块。排序第一关键字是左端点所在块编号,第二关键字是右端点所在块编号,第三关键字是时间。
每次回答询问时,先从上一个询问的时间“穿越”到当前询问的时间:如果当前询问的时间更靠后,则顺序执行所有修改,直到达到当前询问时间;如果当前询问的时间更靠前,则“时光倒流”,还原所有多余的修改。进行推移时间的操作时,如果涉及到当前区间内的位置的修改,要对答案进行相应的维护。
接下来我们来简要地证明一下复杂度!(不是非常关心证明的同学,可以直接跳到结论部分……)
推移时间、移动左端点、移动右端点的操作都是(O(1))的。
对于时间的移动,对于左右端点所在块不变的情况,时间是单调向右移的,总共(O(n)); 左右端点之一所在块改变,时间最多从(n)直接移回(1),复杂度(O(n));左右端点所在块各有(O(n^{frac{1}{3}}))种,两两组合有(O(n^{frac{2}{3}}))种,每种都是(O(n)),总复杂度是(O(n^{frac{5}{3}}))。
对于右端点的移动,在左右端点所在块不变时,每次最多移动(n^{frac{2}{3}}),一共最多有(n)次右端点的移动,复杂度是(O(n^{frac{5}{3}}));当左端点所在块改变时,右端点最多从(n)一直移动到(1),距离是(n),最多有(n^{frac{1}{3}})次这样的移动,复杂度是(O(n^{frac{4}{3}}));总共右端点移动的复杂度是(O(n^{frac{5}{3}}))。
对于左端点的移动,在左端点块不变时,一次移动距离最多(n^{frac{2}{3}}),总共(O(n^{frac{5}{3}}))。而跨块时,由于左端点所在块是单调向右移动的,复杂度最大的情况就是每跨一个块都是从前一个块的最左侧跑到后一个块的最右侧,距离(O(n^{frac{1}{3}})),总复杂度(O(n))。所以总共左端点移动的复杂度是(O(n^{frac{5}{3}}))。
结论:上述排序方法实现的莫队复杂度是(O(n^{frac{5}{3}}))。
想试一试?可以做一下这道题:BZOJ 2120 数颜色。这道题也可以使用分块+块内排序二分来做,我也用这个方法写过一篇题解,欢迎阅读。
这道题我的莫队AC代码:
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
#define space putchar(' ')
#define enter putchar('
')
template <class T>
void read(T &x){
char c;
bool op = 0;
while(c = getchar(), c < '0' || c > '9')
if(c == '-') op = 1;
x = c - '0';
while(c = getchar(), c >= '0' && c <= '9')
x = x * 10 + c - '0';
if(op) x = -x;
}
template <class T>
void write(T x){
if(x < 0) x = -x, putchar('-');
if(x >= 10) write(x / 10);
putchar('0' + x % 10);
}
const int N = 10005, M = 1000005, B = 464;
int n, m, pl = 1, pr = 0, cur, res, ans[N], a[N], cnt[M];
int idxC, idxQ, tim[N], pos[N], val[N], pre[N];
#define bel(x) (((x) - 1) / B + 1)
struct query {
int id, tim, l, r;
bool operator < (const query &b) const {
if(bel(l) != bel(b.l)) return l < b.l;
if(bel(r) != bel(b.r)) return r < b.r;
return id < b.id;
}
} q[N];
void change_add(int cur){
if(pos[cur] >= pl && pos[cur] <= pr){
cnt[a[pos[cur]]]--;
if(!cnt[a[pos[cur]]]) res--;
}
pre[cur] = a[pos[cur]];
a[pos[cur]] = val[cur];
if(pos[cur] >= pl && pos[cur] <= pr){
if(!cnt[a[pos[cur]]]) res++;
cnt[a[pos[cur]]]++;
}
}
void change_del(int cur){
if(pos[cur] >= pl && pos[cur] <= pr){
cnt[a[pos[cur]]]--;
if(!cnt[a[pos[cur]]]) res--;
}
a[pos[cur]] = pre[cur];
if(pos[cur] >= pl && pos[cur] <= pr){
if(!cnt[a[pos[cur]]]) res++;
cnt[a[pos[cur]]]++;
}
}
void change(int now){
while(cur < idxC && tim[cur + 1] <= now) change_add(++cur);
while(cur && tim[cur] > now) change_del(cur--);
}
void add(int p){
if(!cnt[a[p]]) res++;
cnt[a[p]]++;
}
void del(int p){
cnt[a[p]]--;
if(!cnt[a[p]]) res--;
}
bool isQ(){
char op[2];
scanf("%s", op);
return op[0] == 'Q';
}
int main(){
read(n), read(m);
for(int i = 1; i <= n; i++) read(a[i]);
for(int i = 1; i <= m; i++){
if(isQ()) idxQ++, q[idxQ].id = idxQ, q[idxQ].tim = i, read(q[idxQ].l), read(q[idxQ].r);
else tim[++idxC] = i, read(pos[idxC]), read(val[idxC]);
}
sort(q + 1, q + idxQ + 1);
for(int i = 1; i <= idxQ; i++){
change(q[i].tim);
while(pl > q[i].l) add(--pl);
while(pr < q[i].r) add(++pr);
while(pl < q[i].l) del(pl++);
while(pr > q[i].r) del(pr--);
ans[q[i].id] = res;
}
for(int i = 1; i <= idxQ; i++)
write(ans[i]), enter;
return 0;
}
树上莫队
在序列中,莫队算法号称“可以解决一切区间问题”;而把莫队算法搬到树上,它在某种程度上也可以“解决一切树上路径问题”。
学习树上莫队,需要以下预备知识:
- LCA(最近公共祖先)
- 莫队
- 能 AC BZOJ 1086 王室联邦
首先,我们要对树进行分块!如何分块?请参考 BZOJ 1086 的分块方法,欢迎参考我的题解。这个方法可以保证每一块的大小都在([B, 3B])之间。(我不知道为什么世界上会有这样一道题,简直是出题人为以后打算学树上莫队的选手提供完美的练习题啊……)
然后我们还是要对所有询问进行排序。排序依据是(假设我们做的是带修改树上莫队,一块的大小是(n^{frac{2}{3}}))左端点所在块的编号、右端点所在块的编号、时间。
与上面的“数颜色”这道题类似,这道题也是使用“时间推移”和“时间倒流”的技能实现修改操作。序列莫队中的“左右端点”的概念,在树上莫队中对应着“起点终点”。
最后一个重要的问题就是:如何移动起点终点?
在序列中,左右端点的移动方式是显然的,一个端点的移动只有两个方向——左和右,而它们带来的影响也是显然的——区间增加一个元素或删除一个元素。
然而树上莫队却不是非常显然……最佳的方案是:维护一个(vis)布尔数组,记录每个节点是否在当前处理的路径上(LCA非常难办,我们在维护路径上的点时不包括LCA,求答案的时候临时把LCA加上)。每次从上一个询问((u_s, v_s))转移到当前询问((u_t, v_t))时,我们要做的是——把路径((u_s, u_t))和((v_s, v_t))上的点的vis逐个取反,同时对应地维护答案。
对于上面的做法的正确性,VFleaKing的博客中有证明,证明部分摘录如下(Xor表示类似异或的操作,即节点出现两侧会消掉):
T(v, u) = S(root, v) xor S(root, u)(摘者注:显然等式右侧是u到v的路径上除lca以外的点)
观察将curV移动到targetV前后T(curV, curU)变化:
T(curV, curU) = S(root, curV) xor S(root, curU)
T(targetV, curU) = S(root, targetV) xor S(root, curU)
取对称差:
T(curV, curU) xor T(targetV, curU)= (S(root, curV) xor S(root, curU)) xor (S(root, targetV) xor S(root, curU))
由于对称差的交换律、结合律:
T(curV, curU) xor T(targetV, curU)= S(root, curV) xor S(root, targetV)
两边同时xor T(curV, curU):
T(targetV, curU)= T(curV, curU) xor S(root, curV) xor S(root, targetV)
发现最后两项很爽……哇哈哈
T(targetV, curU)= T(curV, curU) xor T(curV, targetV)
当然,你也可以画个图,感性地理解一下这种操作……
那么接下来我们看一道例题!WC2013, BZOJ3052, UOJ58 糖果公园
我在这里贴一下我的代码供大家参考。
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
template <class T>
void read(T &x){
char c;
bool op = 0;
while(c = getchar(), c < '0' || c > '9')
if(c == '-') op = 1;
x = c - '0';
while(c = getchar(), c >= '0' && c <= '9')
x = x * 10 + c - '0';
if(op) x = -x;
}
template <class T>
void write(T x){
if(x < 0) putchar('-'), x = -x;
if(x >= 10) write(x / 10);
putchar('0' + x % 10);
}
#define space putchar(' ')
#define enter putchar('
')
const int N = 200005, B = 2005;
int n, maxcol, m, bel[N], idx, stk[N], top, pu, pv, cur;
int cntQ, cntC, tim[N], pos[N], newx[N], col[N], pre[N], cnt[N];
ll val[N], wei[N], res, ans[N];
int ecnt, go[2*N], nxt[2*N], adj[N];
int fa[N], lg[2*N], dep[N], seq[2*N], seq_cnt, seq_pos[N], mi[2*N][20];
bool vis[N];
struct query {
int id, tim, u, v;
bool operator < (const query &b) const {
if(bel[u] != bel[b.u]) return bel[u] < bel[b.u];
if(bel[v] != bel[b.v]) return bel[v] < bel[b.v];
return tim < b.tim;
}
} q[N];
void add(int u, int v){
go[++ecnt] = v;
nxt[ecnt] = adj[u];
adj[u] = ecnt;
}
void dfs(int u, int pre){
dep[u] = dep[pre] + 1, fa[u] = pre;
seq[++seq_cnt] = u, seq_pos[u] = seq_cnt;
int st = top;
for(int e = adj[u], v; e; e = nxt[e])
if(v = go[e], v != pre){
dfs(v, u);
seq[++seq_cnt] = u;
if(top - st > B){
idx++;
while(top > st) bel[stk[top--]] = idx;
}
}
stk[++top] = u;
}
int Min(int a, int b){
return dep[a] < dep[b] ? a : b;
}
void lca_init(){
for(int i = 1, j = 0; i <= seq_cnt; i++)
lg[i] = i == (1 << (j + 1)) ? ++j : j;
for(int i = 1; i <= seq_cnt; i++) mi[i][0] = seq[i];
for(int j = 1; (1 << j) <= seq_cnt; j++)
for(int i = 1; i + (1 << j) - 1 <= seq_cnt; i++)
mi[i][j] = Min(mi[i][j - 1], mi[i + (1 << (j - 1))][j - 1]);
}
int lca(int u, int v){
u = seq_pos[u], v = seq_pos[v];
if(u > v) swap(u, v);
int j = lg[v - u + 1];
return Min(mi[u][j], mi[v - (1 << j) + 1][j]);
}
void reverse(int u){
if(vis[u]) res -= wei[cnt[col[u]]] * val[col[u]], cnt[col[u]]--;
else cnt[col[u]]++, res += wei[cnt[col[u]]] * val[col[u]];
vis[u] ^= 1;
}
void move(int u, int v){
int w = lca(u, v);
while(u != w) reverse(u), u = fa[u];
while(v != w) reverse(v), v = fa[v];
}
void travel_ahead(){
bool flag = 0;
if(vis[pos[cur]]) flag = 1, reverse(pos[cur]);
pre[cur] = col[pos[cur]];
col[pos[cur]] = newx[cur];
if(flag) reverse(pos[cur]);
}
void travel_back(){
bool flag = 0;
if(vis[pos[cur]]) flag = 1, reverse(pos[cur]);
col[pos[cur]] = pre[cur];
if(flag) reverse(pos[cur]);
}
void time_travel(int tar){
while(cur < cntC && tim[cur + 1] <= tar) cur++, travel_ahead();
while(cur && tim[cur] > tar) travel_back(), cur--;
}
int main(){
read(n), read(maxcol), read(m);
for(int i = 1; i <= maxcol; i++) read(val[i]);
for(int i = 1; i <= n; i++) read(wei[i]);
for(int i = 1, u, v; i < n; i++)
read(u), read(v), add(u, v), add(v, u);
for(int i = 1; i <= n; i++) read(col[i]);
for(int i = 1, op; i <= m; i++){
read(op);
if(op) q[++cntQ].tim = i, q[cntQ].id = cntQ, read(q[cntQ].u), read(q[cntQ].v);
else tim[++cntC] = i, read(pos[cntC]), read(newx[cntC]);
}
dfs(1, 0);
lca_init();
while(top) bel[stk[top--]] = idx;
sort(q + 1, q + cntQ + 1);
pu = pv = 1;
for(int i = 1; i <= cntQ; i++){
time_travel(q[i].tim);
move(pu, q[i].u), pu = q[i].u;
move(pv, q[i].v), pv = q[i].v;
reverse(lca(pu, pv));
ans[q[i].id] = res;
reverse(lca(pu, pv));
}
for(int i = 1; i <= cntQ; i++)
write(ans[i]), enter;
return 0;
}