莫队算法是一个非常优雅的暴力算法,一般用于给定区间求出这个区间的某种性质,如果我们知道区间\([l,r]\)的性质,我们可以以\(O(1)\)或者\(O(log)\)的极小的复杂度求出区间\([l-1,r]\)、\([l+1,r]\)、\([l,r+1]\)和\([l,r-1]\)的话,一般都可以用莫队解决了。
例子:给定\(n\)个数字,\(m\)个询问,每个询问有两个参数\(l\)和\(r\),求解区间\([l,r]\)中一共有多少个不同的数字。(luogu P1972)
莫队算法的代码非常简单,总体来说,先离线储存所有的查询区间\([l,r]\),将这些区间以某种方式(重点!后面给出)排序,我们设置一个\(nowl,nowr,ans\)表示现在的\(ans\)是区间\([nowl,nowr]\)的答案,然后移动\([nowl,nowr]\)到下一个区间\([l,r]\),就是把\(nowl\)一单位一单位地移到\(l\),\(nowr\)同\(nowl\),前提是我们知道区间\([l,r]\)向\([l+1,r]\)等等的转移公式。
那么莫队的排序算法是什么呢?
我们把这\(n\)个数分成\(\sqrt n\)块,然后我们知道对于一个区间\([l,r]\),它的左端点在\(\frac{l}{\sqrt n}\)块里(第一块标号\(0\)),给查询区间\([l,r]\)一个\(id\),\(id=\frac{l}{\sqrt n}\),然后把区间以\(id\)从小到大排序,\(id\)相同的以\(r\)从小到大排序。这样每次区间转移的速度很快!
复杂度的话,假设转移复杂度为\(O(1)\)
我们先来考虑区间\([nowl,nowr]\)--->\([l,r]\)在保证\(id\)相同(\(nowl\)与\(l\)在同一个块里),从\(nowl\)到\(l\)复杂度,最大为\(O(\sqrt n) \times O(1)\),而其实我们很容易发现\(nowl\)从上个区间到这个区间的跑路长度取平均就是\(\sqrt n\),由于它会跑\(m\)次路,复杂度就是\(O(m \sqrt n)\)了。
那么\(r\)呢?
在\(id\)相同的情况下,\(r\)是一路向右而不会回头的!因为排序就这样排的!所以每个块里面\(r\)最远跑路为\(n\),一共\(\sqrt n\)块,于是复杂度为\(O(n\sqrt n)\)。
那么总体复杂度就是相加得到的\(O(n\sqrt n)\)
这里有一道luoguP1972的升级版,luoguP2709,是统计的每个数字的个数,并且开方求和。
这题很明显有\(O(1)\)的转移方法:\(num[i]\)表示\(i\)数字在\([nowl,nowr]\)区间的个数,那么一旦移动\(nowl\)或\(nowr\),我们可以对单个数字\(k\)进行更新,可能是\(num[k]++\)(也可能是\(--\)),那么这个对答案的更新也很明显,\(ans=ans-num[k]^2+(num[k]+1(\)或\(-1))^2\)。
那么这题可以用莫队做,一下是你们基本不可能读得下去的代码:
#include <bits/stdc++.h>
using namespace std;
int n,m,k;
int len;
int z,y;
long long now;
int a[100005];
long long c[100005];
long long ans[100005];
struct ha
{
int l,r,id,used;
}q[100005];
bool cmp(const ha &aa,const ha &bb)
{
if(aa.id==bb.id)return aa.r<bb.r;
return aa.id<bb.id;
}
void deal(int x,int v)
{
now-=c[x]*c[x];
c[x]+=v;
now+=c[x]*c[x];
}
void work()
{
z=1,y=1;
deal(a[1],1);
for(int i=1;i<=m;i++)
{
while(q[i].r>y)
{
y++;
deal(a[y],1);
}
while(q[i].l<z)
{
z--;
deal(a[z],1);
}
while(q[i].r<y)
{
deal(a[y],-1);
y--;
}
while(q[i].l>z)
{
deal(a[z],-1);
z++;
}
ans[q[i].used]=now;
}
}
int main()
{
scanf("%d%d%d",&n,&m,&k);
len=sqrt(n);
for(int i=1;i<=n;i++)
scanf("%d",&a[i]);
for(int i=1;i<=m;i++)
{
scanf("%d%d",&q[i].l,&q[i].r);
q[i].id=q[i].l/len;
q[i].used=i;
}
sort(q+1,q+1+m,cmp);
work();
for(int i=1;i<=m;i++)
printf("%lld\n",ans[i]);
return 0;
}
这是很基础的莫队了,还有较难的,带修改的莫队:例题luogu P1903。
这题中,我们需要吧区间变成三个参数\([l,r,k]\),\(k\)是时间戳,我们转移区间的时候\(k\)也是会变的,前一个\(k\)和现在的\(k\)之间就包涵了修改操作,其实整体依然很简单,我们把\(l\),\(r\)都拿去除以分块,然后以\(idl\)(表示\(l\)所在的块)排序,以\(idr\)为第二关键字排序,最后以\(k\)排序,此处分块也不是以\(\sqrt n\)了,而是\(n^{\frac{2}{3}}\),\(l\)移动的复杂度是\(O(mn^{\frac{2}{3}})\),\(r\)也是\(O(mn^{\frac{2}{3}})\),而\(k\)的复杂度求解,可以看下这篇博客(码字太累了!)
https://blog.csdn.net/chenxiaoran666/article/details/82220385
那么转移和前面的一样,只不过还要同时移动\(k\),修改时简单判断就可以了。
然后直接上代码吧。
#include <bits/stdc++.h>
using namespace std;
int n,m,qm;
int len;
int z,y,now,nowk=0;
int num[1000005];
int a[100005];
char c[10];
int updatex[100005];
int updatev[100005];
int lastv[100005];
int ans[100005];
struct ha
{
int l,r,idl,idr,used;
}q[100005];
int erfen(long long l,long long r)
{
if(l==r)
return l;
int mid=(l+r)/2;
if(((long long)mid)*mid*mid>=((long long)n)*n)return erfen(l,mid);
else return erfen(mid+1,r);
}
bool cmp(const ha &aa,const ha &bb)
{
if(aa.idl!=bb.idl)return aa.idl<bb.idl;
if(aa.idr!=bb.idr)return aa.idr<bb.idr;
return aa.used<bb.used;
}
void deal(int x,int v)
{
if(num[a[x]]==0&&v==1)
now++;
num[a[x]]+=v;
if(num[a[x]]==0&&v==-1)
now--;
}
void add(int i,int x,int v)
{
if(x==0&&v==0)return ;
lastv[i]=a[x];
if(x<=y&&x>=z)
{
num[a[x]]--;
if(num[a[x]]==0)
now--;
a[x]=v;
if(num[v]==0)
now++;
num[v]++;
}
else
a[x]=v;
}
void del(int i,int x,int v)
{
if(x==0&&v==0)return ;
if(x<=y&&x>=z)
{
num[a[x]]--;
if(num[a[x]]==0)
now--;
a[x]=lastv[i];
if(num[lastv[i]]==0)
now++;
num[lastv[i]]++;
}
else
a[x]=lastv[i];
}
void work()
{
z=1,y=1;
now=1;
num[a[1]]++;
for(int i=1;i<=qm;i++)
{
while(q[i].r>y)
{
y++;
deal(y,1);
}
while(q[i].l<z)
{
z--;
deal(z,1);
}
while(q[i].r<y)
{
deal(y,-1);
y--;
}
while(q[i].l>z)
{
deal(z,-1);
z++;
}
while(nowk<q[i].used)
{
nowk++;
add(nowk,updatex[nowk],updatev[nowk]);
}
while(nowk>q[i].used)
{
del(nowk,updatex[nowk],updatev[nowk]);
nowk--;
}
ans[q[i].used]=now;
}
}
int main()
{
memset(ans,0x3f,sizeof(ans));
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
scanf("%d",&a[i]);
len=erfen(1,n);
for(int i=1;i<=m;i++)
{
scanf("%s",c+1);
if(c[1]=='Q')
{
qm++;
q[qm].used=i;
scanf("%d%d",&q[qm].l,&q[qm].r);
q[qm].idl=q[qm].l/len;
q[qm].idr=q[qm].r/len;
}
else
scanf("%d%d",&updatex[i],&updatev[i]);
}
sort(q+1,q+1+qm,cmp);
work();
for(int i=1;i<=m;i++)
{
if(ans[i]<1e8)
printf("%d\n",ans[i]);
}
return 0;
}