zoukankan      html  css  js  c++  java
  • 莫队算法解析

    首先,本人能力有限,不一定能够讲得很清楚,但我尽力让读者看懂

    先来看一个题目:

    HH的项链(洛谷1972)

    题目大意:给你一个长度为n的序列,序列中数字,有m个询问,每个询问区间为[l,r],求每个询问的区间内有多少个不同的数字

    首先看到这道题目让你用最暴力的方法做你会怎么做,肯定是O(n*m)的暴力,代码如下

     

    #include<iostream>
    #include<cstdio>
    #include<cstring>
    #include<string>
    #include<cstdlib>
    #include<cmath>
    #include<algorithm>
    #define in(i) (i=read())
    using namespace std;
    typedef long long lol;
    lol read()
    {
        lol ans=0,f=1;
        char i=getchar();
        while(i<'0'||i>'9') {if(i=='-') f=-1; i=getchar();}
        while(i>='0'&&i<='9') {ans=(ans<<3)+(ans<<1)+i-'0';i=getchar();}
        return ans*f;
    }
    int c[500010];
    int vis[1000010];
    int main()
    {
        int n,m;
        in(n);
        for(int i=1;i<=n;i++) in(c[i]);
        in(m);
        for(int i=1;i<=m;i++) {
            memset(vis,0,sizeof(vis));
            int ans=0;
            int l,r;
            in(l);in(r);
            for(int j=l;j<=r;j++)
                if(!vis[c[j]]) {
                    ans++;
                    vis[c[j]]=1;
                }
            cout<<ans<<endl;
        }
    }
    View Code

    但是这样很显然对于题目数据是过不了的,下面我们来看一看另外一个暴力方法

    #include<iostream>
    #include<cstdio>
    #include<cstring>
    #include<string>
    #include<cstdlib>
    #include<cmath>
    #include<algorithm>
    #define in(i) (i=read())
    using namespace std;
    typedef long long lol;
    lol read()
    {
        lol ans=0,f=1;
        char i=getchar();
        while(i<'0'||i>'9')
        {
            if(i=='-') f=-1;
            i=getchar();
        }
        while(i>='0'&&i<='9')
        {
            ans=(ans<<3)+(ans<<1)+i-'0';
            i=getchar();
        }
        return ans*f;
    }
    int ans=0;
    int c[500010];
    int cnt[1000010];
    int add(int x){
        cnt[c[x]]++;
        if(cnt[c[x]]==1) ans++;
    }
    int remove(int x){
        cnt[c[x]]--;
        if(!cnt[c[x]]) ans--;
    }
    int main()
    {
        int n,m; in(n);
        for(int i=1;i<=n;i++) in(c[i]);
        in(m);
        for(int i=1,curl=1,curr=0;i<=m;i++) {
            int l,r; in(l);in(r);
            while(curl<l) remove(curl++);
            while(curl>l) add(--curl);
            while(curr<r) add(++curr);
            while(curr>r) remove(curr--);
            cout<<ans<<endl;
        }
        return 0;
    }
    View Code

     

    6
    1 2 3 4 3 5
    3
    1 2
    3 5
    2 6
    
    
    
    2
    2
    4

    我们可以拿这组样例来模拟一下.

    关于这个add函数和remove函数,我们可以这样理解,最开始一组[l,r],curl和curr肯定最后都在1和2,那么我们之后的解都由这一组解转移而来

    在询问过第一组数据后,我们可以看成区间[1,2]都访问过1次,接下来对于第二组数据[3,5],我们用一个指针curl,curr以及一个添加函数add,删除函数remove来维护,每进一次add函数,看成访问次数+1,每进一次remove,看成访问次数-1,那么我们将[1,3)全部remove一次,将[2,5]全部add一次,中间差的部分是不是还是可以看成访问次数=0,这样一来统计的就还是区间[3,5]的数字,如果还是不理解,自己可以模拟一下

    其实这就是莫队的核心代码,可以看出这个代码的时间复杂度是和curl和curr的移动次数有关的,我们当然希望移动次数越小越好,但是每次询问的l和r无法保证是递增的,怎么办呢

    我们可以将其分块再排序

    我们将这个序列分成n个区间,再按询问区间所在的块的序号为第一关键字,右端点为第二关键字从小到大排序,在用上面的程序来玩,就可以保证这两个指针不会走多余的路了

    举个例子假设现在我们有3个长度为3的段{[0,2],[3,5],[6,8]}

    有一些区间,{[0,3],[1,7],[2,8],[7,8],[4,8],[4,4],[1,2]}

    首先按左端点所在块的序号排序,那么顺序变成

    {[0,3],[1,7],[2,8],[1,2],[4,8],[4,4],[7,8]}

    再按右端点排序

    {[1,2],[0,3],[1,7],[2,8,[4,4],[4,8],[7,8]}

    接下来还要管吗,直接用上面的程序玩就行了

    时间复杂度证明

     

    右端点移动:
    首先我们考虑一个块里面的转移情况
    由于一个块里面的询问都按右端点排序
    所以我们右端点在一个块里面最多移动n次
    有 √n个块,那么同一个块内的右端点移动最多就是O(n√n)
    然后考虑从一个块到另一个块导致的右端点变化
    最坏情况,右端点由n到1,那么移动n次
    有 √n个块
    那么从一个块到另一个块的事件只会发生O(√n)次……
    所以这种右端点移动的次数也是O(n√n)次
    没有别的事件导致右端点移动了
    左端点移动:
    同一个块里面,由于左端点都在一个长度为O(√n)的区间里面
    所以在同一块里面移动一次,左端点最多变化O(√n)
    总共有n个询问……
    所以同一块里面的移动最多n次
    那么同一个块里面的左端点变化最多是O(n√n)的
    考虑越块
    每由第i个块到第i+1个块,左端点最坏加上O(√n)
    总共能加上O(√n)次
    所以跨越块导致的左端点移动是O(n)的
    综上,分块做法是O(n∗√n)。

     

     

    那么这道题就好解了,代码如下

     

    #include<iostream>
    #include<cstdio>
    #include<cstring>
    #include<string>
    #include<cstdlib>
    #include<cmath>
    #include<algorithm>
    #define in(i) (i=read())
    using namespace std;
    typedef long long lol;
    lol read()
    {
        lol ans=0,f=1;
        char i=getchar();
        while(i<'0'||i>'9')
        {
            if(i=='-') f=-1;
            i=getchar();
        }
        while(i>='0'&&i<='9')
        {
            ans=(ans<<3)+(ans<<1)+i-'0';
            i=getchar();
        }
        return ans*f;
    }
    struct query
    {
        lol l,r,id,pos;
    }e[200010];
    lol tot;
    lol c[500010],cnt[1000010],ans[200010];
    lol cmp(query a,query b){
        return (a.pos==b.pos)?(a.r<b.r):(a.pos<b.pos);
    }
    void add(lol x){
        cnt[c[x]]++;
        if(cnt[c[x]]==1) tot++;
    }
    void remove(lol x){
        cnt[c[x]]--;
        if(cnt[c[x]]==0) tot--;
    }
    int main()
    {
        lol n,m; in(n);
        lol block=(lol)sqrt(n);//分成√n个块
        for(lol i=1;i<=n;i++) in(c[i]);
        in(m);
        for(lol i=1;i<=m;i++) {
            in(e[i].l);in(e[i].r);
            e[i].id=i;
            e[i].pos=(e[i].l-1)/block+1;
        }
        sort(e+1,e+1+m,cmp);
        for(lol i=1,curl=1,curr=0;i<=m;i++) {
            lol l=e[i].l,r=e[i].r;
            while(curl<l) remove(curl++);
            while(curl>l) add(--curl);
            while(curr<r) add(++curr);
            while(curr>r) remove(curr--);
            ans[e[i].id]=tot;
        }
        for(lol i=1;i<=m;i++)
            printf("%lld
    ",ans[i]);
        return 0;
    }
    View Code

    (有错误希望大家指出)

    希望大家看了这篇博客后可以更好的理解莫队

     

    博主蒟蒻,随意转载.但必须附上原文链接
    http://www.cnblogs.com/real-l/
  • 相关阅读:
    损失函数
    numpy中的broadcast
    混合模型
    贝叶斯学习
    python3中输出不换行
    C++11 实现 argsort
    Python中的闭包
    C语言 fread()与fwrite()函数说明与示例
    DFT与傅里叶变换的理解
    MISRA C:2012 Dir-1.1(只记录常犯的错误和常用的规则)Bit-fields inlineC99,NOT support in C90 #pragma
  • 原文地址:https://www.cnblogs.com/real-l/p/8782196.html
Copyright © 2011-2022 走看看