zoukankan      html  css  js  c++  java
  • BM(Berlekamp-Massey)算法

    线性递推的题目区域赛里还是挺多的,还是有必要学一下


    ~ BM(Berlekamp-Massey)算法 ~

    有一个$n$阶线性递推$f$,想要计算$f(m)$,有一种常用的办法是矩阵快速幂,复杂度是$O(n^3logm)$

    在不少情况下这已经够用了,但是如果$n$比较大、到了$10^3$级别,这就不太适用了

    而BM算法能将这个复杂度压低到$O(n^2logm)$,若加上NTT优化的话能做到$O(n^2+nlognlogm)$,十分厉害

    这个算法的核心是将$f(m)$用递推的前$n$项表示

    即,已知$f(0),...,f(n-1)$和递推式$f(m)=a_0f(m-1)+...+a_{n-1}f(m-n)$,该算法是求出系数$W_0,...,W_{n-1}$,使得$f(m)=W_0f(n-1)+...+W_{n-1}f(0)$

    看似无从下手?实际上只要大力展开就行了

    根据定义,有(只是写成$sum$的形式而已)

    [f(m)=sum_{i=0}^{n-1}a_i f(m-1-i)]

    而对于每一项再次展开,即

    [f(m-1-i)=sum_{j=0}^{n-1}a_j f(m-1-i-1-j)]

    全部代入,能得到

    [f(m)=sum_{i=0}^{n-1}sum_{j=0}^{n-1}a_ia_j f(m-2-i-j)]

    把式子写的更好看一点,就是

    [f(m)=sum_{k=0}^{2n-2}sum_{i+j=k}a_ia_j f(m-2-k)]

    这样做之后有什么用呢?

    在原本的递推式中,$f(m)$可以通过$f(m-1),...,f(m-n)$这$n$个项表示

    各项展开后,就可以通过$f(m-2),...,f(m-2n)$表示

    事实上,我们可以再依次对$f(m-i),2leq ileq n$展开,并将系数向$f(m-i-1),...,f(m-i-n)$并入,最终就能把原递推式通过$f(m-n-1),...,f(m-2n)$这$n$项表示

    于是可以得到一个新的$n$阶递推式,记为$f(m)=b_0f(m-n+1),...,b_{n-1}f(m-2n)$

    再用新递推式将各项展开,就可以通过$f(m-2n-2),...,f(m-4n)$表示

    再用原递推式展开$f(m-2n-i),2leq ileq n$并向前合并系数,最终就能把原递推式通过$f(m-3n+1),...,f(m-4n)$这$n$项表示

    之后都是类似的了,不再赘述

    有了上面的思路,就可以用类似快速幂的方法,得到$f(m)=W_0f(m-(k-1)n+1),...,W_{n-1}f(m-kn)$这样的展开式,其中$m-kn<n$

    余数$m-kn$是我们不喜欢的,但也没有必要整体再向前推,一开始计算时算出$f(0),...,f(2n-1)$就够了

    按照上述思路能这样实现:

    #include <cstdio>
    #include <cstring>
    using namespace std;
    
    typedef long long ll;
    const int MOD=1000000007;
    const int N=1005;
    
    
    int n,m;
    int a[N];
    int f[N<<1];
    
    int tmp[N<<1];
    
    void mul(int *y,int *x)
    {
        memset(tmp,0,sizeof(tmp));
        
        for(int i=0;i<n;i++)
            for(int j=0;j<n;j++)
                tmp[i+j]=(tmp[i+j]+ll(y[i])*x[j])%MOD;
        
        for(int i=0;i<n-1;i++)
            for(int j=0;j<n;j++)
                tmp[i+j+1]=(tmp[i+j+1]+ll(tmp[i])*a[j])%MOD;
        
        for(int i=0;i<n;i++)
            y[i]=tmp[i+n-1]; 
    }
    
    int w[N<<1],x[N<<1];
    
    int BM()
    {
        if(m<(n<<1))
            return f[m];
        
        for(int i=0;i<n;i++)
            x[i]=a[i],w[i]=a[i];
        
        int t=(m-n)/n;
        int rem=m-n-t*n;
        
        while(t)
        {
            if(t&1)
                mul(w,x);
            mul(x,x);
            t>>=1;
        }
        
        int res=0;
        for(int i=0;i<n;i++)
            res=(res+ll(w[i])*f[rem+n-i-1])%MOD;
        return res;
    }
    
    int main()
    {
        scanf("%d%d",&n,&m);
        for(int i=0;i<n;i++)
            scanf("%d",&a[i]);
        for(int i=0;i<n;i++)
            scanf("%d",&f[i]);
        for(int i=n;i<(n<<1);i++)
            for(int j=1;j<=n;j++)
                f[i]=(f[i]+ll(a[j-1])*f[n-j])%MOD;
        
        printf("%d
    ",BM());
        return 0;
    }
    View Code

    想做的更快的话,一个是要写NTT,另一个是合并系数会比较困难,待补


    因为这题学的BM:牛客ACM 882B ($Eddy$ $Walker$ $2$)

    $m ightarrow infty$时,$f(m) ightarrow frac{2}{k+1}$ (并不会证...)

    从rls那里学了一个证明:

    走$k$步,期望能走的长度是$1+2+...+k=frac{k(k+1)}{2}$

    那么在这段距离中,每个位置被走过的概率就是$frac{k}{frac{k(k+1)}{2}}=frac{2}{k+1}$

    在其他时候,直接套上面的板子即可

    牛客的玄学评测机,同一份代码能差出500ms = =

    #include <cstdio>
    #include <cstring>
    using namespace std;
     
    typedef long long ll;
    const int MOD=1000000007;
    const int N=1100;
     
    inline int quickpow(int x,int t)
    {
        int res=1;
        while(t)
        {
            if(t&1)
                res=ll(res)*x%MOD;
            x=ll(x)*x%MOD;
            t>>=1;
        }
        return res;
    }
     
    inline int rev(int x)
    {
        return quickpow(x,MOD-2);
    }
     
    int n,rn;
    ll m;
    int a[N];
    int f[N<<1];
     
    int tmp[N<<1];
     
    void mul(int *y,int *x)
    {
        memset(tmp,0,sizeof(tmp));
         
        for(int i=0;i<n;i++)
            for(int j=0;j<n;j++)
                tmp[i+j]=(tmp[i+j]+ll(y[i])*x[j])%MOD;
         
        for(int i=0;i<n-1;i++)
            for(int j=0;j<n;j++)
                tmp[i+j+1]=(tmp[i+j+1]+ll(tmp[i])*a[j])%MOD;
         
        for(int i=0;i<n;i++)
            y[i]=tmp[i+n-1];
    }
     
    int w[N<<1],x[N<<1];
     
    int BM()
    {
        if(m<(n<<1))
            return f[m];
         
        for(int i=0;i<n;i++)
            x[i]=a[i],w[i]=a[i];
         
        ll t=(m-n)/n;
        int rem=m-n-t*n;
         
        while(t)
        {
            if(t&1)
                mul(w,x);
            mul(x,x);
            t>>=1;
        }
         
        int res=0;
        for(int i=0;i<n;i++)
            res=(res+ll(w[i])*f[rem+n-i-1])%MOD;
        return res;
    }
     
    int main()
    {
        int T;
        scanf("%d",&T);
        while(T--)
        {
            scanf("%d%lld",&n,&m);
             
            if(m==-1)
            {
                printf("%d
    ",2LL*rev(n+1)%MOD);
                continue;
            }
             
            rn=rev(n);
            for(int i=0;i<n;i++)
                a[i]=rn;
             
            memset(f,0,sizeof(f));
            f[0]=1;
            for(int i=1;i<(n<<1);i++)
                for(int j=1;j<=n && j<=i;j++)
                    f[i]=(f[i]+ll(rn)*f[i-j])%MOD;
             
            printf("%d
    ",BM());
        }
         
        return 0;
    }
    View Code

    比较特定的知识点吧,以后遇到就是赚到(然后发现强制NTT,直接白给= =)

    (完)

  • 相关阅读:
    Block的强强引用问题(循环引用)
    自己封装的下载方法
    MJRefresh上拉刷新下拉加载
    JavaScript 模块的循环加载
    webpack使用require注意事项
    console.log高级用法
    path.resolve()和path.join()的区别
    深入理解react
    react children技巧总结
    揭秘css
  • 原文地址:https://www.cnblogs.com/LiuRunky/p/Berlekamp_Massey.html
Copyright © 2011-2022 走看看