zoukankan      html  css  js  c++  java
  • 【模板】常系数线性递推

    问题描述

    求一个满足 $K$ 阶齐次线性递推数列 $a_i$ 的第 $n$ 项,即:$a_n = sum_{i=1}^k f_i imes a_{n-i}$.

    分析

      首先写成矩阵快速幂

    $$left( egin{bmatrix} f_1 &f_2 &f_3 &f_4 & cdots &f_{k-2} &f_{k-1} \ 1 &0 &0 &0 & cdots &0 &0 \ 0 &1 &0 &0 & cdots &0 &0\ cdots & cdots& cdots & cdots & cdots & cdots & cdots\ 0 &0 &0 &0 & cdots &1 &0 end{bmatrix} ight) ^n imes egin{bmatrix} a_{k-1} \ a_{k-2} \ cdots \ a_{1} \ a_{0}end{bmatrix} =egin{bmatrix} a_{n+k-1} \ a_{n+k-2} \ cdots \ a_{n+1} \ a_{n}end{bmatrix}$$

    所以我们只需要算出 $M^N imes A$,然后取最后一个数即可。

    使用矩阵快速幂,复杂度 $O(k^3 log_2n)$.

      Carlay-Hamilton定理

    设有 $k$ 个特征值的矩阵 $A$的特征多项式为 $f(lambda ) =prod_{i=1}^k(lambda_i - x)$,则有 $f(A) = 0$,$0$  为零矩阵。

      用这个定理来优化递推

    由前面的矩阵快速幂,我们只要求出 $M^n$就可以了。

    我们考虑 $M$ 的特征多项式 $f(x)$,这是一个 $k$ 次多项式。我们对 $M^n$ 做带余除法 $M^n = f(M) imes g(M) + R(M)$。

    由于 $f(M) = 0$,所以 $M^n equiv  R(M) (mod f(M))$,$R(M)$ 是一个次数不超过 $k-1$ 的多项式。

    也就是说,我们只要求出 $M^n \% f(M)$就可以了

    但是要怎么求呢?我们考虑快速幂的过程(就是倍增)

    假设我们现在已知 $g(M)=M^{2^i} \% f(M)$,现在要求  $h(M)= M^{2^{i+1}} \% f(M)$。

    一个直接的想法是令 $H(M)=g(M) imes g(M)$。但是这样做 $H(x)$ 的次数是 $2k-2$次的。

    那么我们考虑原本的递推关系,$a_n=sumlimits_{i=1}^{k}a_{n-i}*f_i$

    不难得到 $M^n=sumlimits_{i=1} ^{k} M^{n-i} imes f_{i}$

    所以我们可以用这个式子将多余的系数都向前压一位。

    这样我们可以得到一个 $O(k^2  log_2 n)$ 的做法。

    那么有没有优化的余地呢?我们从倍增的过程入手,可以发现 $H(M) = g(M) imes g(M)$ 的过程可以用FFT/NTT加速至 $O(k log_2k)$。

     现在只要解决压系数就可以了,把 $H(M)$ 模 $f(M)$ 即可。

    我们的推导一直用到这个特征多项式 $f(x)$,如何求得呢?

    根据定义, $f(x) = det(xI - M)$,得到

    $$f(x) = |x I - M| = egin{bmatrix} x- a_1 & -a_2 & -a_3 & cdots & -a_{k - 2} & -a_{k - 1} & -a_k \ -1 & x & 0 & cdots & 0 & 0 & 0 \ 0 & -1 & x & cdots & 0 & 0 & 0 \ 0 & 0 & -1 & cdots & 0 & 0 & 0 \ vdots & vdots & vdots & ddots & vdots & vdots & vdots \ 0 & 0 & 0 & cdots & -1 & x & 0 \ 0 & 0& 0 & cdots & 0 & -1 & xend{bmatrix}$$

    对第一行进行展开,得到

    $$f(x) = (x - a_1)M_{11} + (-a_2)M_{12} + cdots + (-a_k)M_{1n} = x ^ k - a_1 x ^ {k - 1} - a_2x ^ {k - 2} - cdots - a_k$$

    代码1:

    $O(k log_2k log_2n)$的做法

    思路其实就是去做一个类似快速幂的操作,然后把乘法改成多项式下的,取模也改成多项式下的

    // luogu-judger-enable-o2
    #include<cstdio>
    #include<algorithm>
    using namespace std;
    
    typedef long long ll;
    const ll mod=998244353;
    const int N=65536+10;
    int n;int k;int rv[20][N];ll rt[20][20];int Len;ll tr1[N];ll tr2[N];long long st[N];long long xs[N];
    ll sg[N];ll a[N];ll res[N];ll irg[N];ll q[N];ll rf[N];int DL=-1;ll ans=0;ll ret[N];
    inline ll po(ll a,ll p){ll r=1;for(;p;p>>=1,a=a*a%mod)if(p&1)r=r*a%mod;return r;}
    inline void ntt(ll* a,int o,int len,int d)//ntt
    {
        for(int i=0;i<len;i++)if(i<rv[d][i])swap(a[i],a[rv[d][i]]);
        for(int k=1,j=1;k<len;k<<=1,j++)
            for(int s=0;s<len;s+=(k<<1))
                for(int i=s,w=1;i<s+k;i++,w=w*rt[o][j]%mod)
                {ll a0=a[i];ll a1=a[i+k]*w%mod;a[i]=(a0+a1)%mod,a[i+k]=(a0+mod-a1)%mod;}
        if(o==1){ll inv=po(len,mod-2);for(int i=0;i<len;i++)(a[i]*=inv)%=mod;}
    }
    inline void poly_inv(ll* a,ll* b,int len)//求逆
    {
        b[0]=po(a[0],mod-2);
        for(int k=1,j=0;k<=len;k<<=1,j++)
        {
            for(int i=0;i<k;i++)tr1[i]=a[i];for(int i=0;i<k;i++)tr2[i]=b[i];
            ntt(tr1,0,k<<1,j);ntt(tr2,0,k<<1,j);
            for(int i=0;i<(k<<1);i++)b[i]=tr2[i]*(2+mod-tr1[i]*tr2[i]%mod)%mod;
            ntt(b,1,k<<1,j);for(int i=k;i<(k<<1);i++)b[i]=0;
        }
    }
    inline void poly_mod(ll* a)//取模
    {
        int mi=(k<<1);while(a[--mi]==0);if(mi<k)return;
        for(int i=0;i<(Len<<1);i++)rf[i]=0;for(int i=0;i<=mi;i++)rf[i]=a[i];
        reverse(rf,rf+mi+1);for(int i=mi-k+1;i<=mi;i++)rf[i]=0;ntt(rf,0,Len<<1,DL+1);
        for(int i=0;i<(Len<<1);i++)q[i]=(rf[i]*irg[i])%mod;ntt(q,1,(Len<<1),DL+1);
        for(int i=mi-k+1;i<=(Len<<1);i++)q[i]=0;reverse(q,q+mi-k+1);ntt(q,0,(Len<<1),DL+1);
        for(int i=0;i<(Len<<1);i++)(q[i]*=sg[i])%=mod;ntt(q,1,(Len<<1),DL+1);
        for(int i=0;i<k;i++)(a[i]+=mod-q[i])%=mod;for(int i=k;i<=mi;i++)a[i]=0;
    }
    int main()
    {
        for(int i=0;i<=15;i++)
            for(int j=0;j<(1<<(i+1));j++)rv[i][j]=(rv[i][j>>1]>>1)|((j&1)<<i);
        for(int t=2,j=1;j<=18;t<<=1,j++)rt[0][j]=po(3,(mod-1)/t);
        for(int t=2,j=1;j<=18;t<<=1,j++)rt[1][j]=po(332748118,(mod-1)/t);
        scanf("%d%d",&n,&k);
        for(Len=1;Len<=k;Len<<=1,DL++); //预处理
        for(int i=1;i<=k;i++){scanf("%lld",&xs[i]);xs[i]=xs[i]<0?xs[i]+mod:xs[i];}
        for(int i=0;i<k;i++){scanf("%lld",&st[i]);st[i]=st[i]<0?st[i]+mod:st[i];}
        for(int i=1;i<=k;i++)sg[k-i]=mod-xs[i];sg[k]=1;for(int i=0;i<=k;i++)ret[i]=sg[i];
        for(int i=0;i<=k;i++)rf[i]=sg[i];reverse(rf,rf+k+1);poly_inv(rf,irg,Len);
        for(int i=0;i<=k;i++)rf[i]=0;ntt(sg,0,Len<<1,DL+1);ntt(irg,0,Len<<1,DL+1);a[1]=1;res[0]=1;
        while(n)//快速幂
        {
            if(n&1)
            {
                ntt(res,0,Len<<1,DL+1);ntt(a,0,Len<<1,DL+1);
                for(int i=0;i<(Len<<1);i++)(res[i]*=a[i])%=mod;
                ntt(res,1,Len<<1,DL+1);ntt(a,1,Len<<1,DL+1);poly_mod(res);
            }ntt(a,0,Len<<1,DL+1);for(int i=0;i<(Len<<1);i++)(a[i]*=a[i])%=mod;
            ntt(a,1,Len<<1,DL+1);poly_mod(a);n>>=1;
        }
        for(int i=0;i<k;i++)(ans+=res[i]*st[i])%=mod;
        printf("%lld",ans);
        return 0;
    }

     代码2:

    $O(k^2 log_2n)$的做法

     BZOJ4161

    #include<cstdio>
    #include<cstring>
    #include<cstdlib>
    #include<cctype>
    #include<cmath>
    #include<iostream>
    #include<algorithm>
    #include<vector>
    #include<set>
    #include<map>
    #include<queue>
    #include<stack>
    #include<cassert>
    typedef long long ll;
    typedef unsigned long long ull;
    using namespace std;
    
    const int P=1000000007;
    const int MAXN=4010;    //2*k+10
    
    int n,k,ans;
    int f[MAXN],h[MAXN];
    
    struct Matrix{ //其实是多项式
        int a[MAXN];
        Matrix (){memset(a,0,sizeof a);}
        int& operator [] (const int &i) {return a[i];}
        int operator [] (const int &i) const {return a[i];}
        inline Matrix operator * (const Matrix &rhs) const
        {
            Matrix ret;
            for(int i=0;i<k;i++)
                for(int j=0;j<k;j++)
                    (ret[i+j]+=1ll*a[i]*rhs[j]%P)%=P;
            for(int i=2*k-2;i>=k;ret[i--]=0)
                for(int j=1;j<=k;j++) //这里就是多项式取模优化的地方
                    (ret[i-j]+=1ll*ret[i]*f[j]%P)%=P; //可以认为是暴力向前压系数
            return ret;
        }
    }res;
    
    Matrix ksm(Matrix a,int b)
    {
        Matrix ret;
        ret[0]=1;
        for(;b;a=a*a,b>>=1) if(b&1) ret=ret*a;
        return ret;
    }
    
    int main()
    {
        scanf("%d%d",&n,&k);
        for(int i=1;i<=k;i++) scanf("%d",&f[i]),f[i]=f[i]>0?f[i]:f[i]+P;
        for(int i=0;i<k;i++) scanf("%d",&h[i]),h[i]=h[i]>0?h[i]:h[i]+P;
        if(n<k) printf("%d
    ",h[n]);
        res[1]=1;ans=0;
        res=ksm(res,n);
        for(int i=0;i<k;i++)  ans=(ans+1ll*res[i]*h[i]%P)%P;
        printf("%d
    ",ans);
    }

    参考链接:

    1. https://www.luogu.org/problemnew/solution/P4723

    2. https://www.luogu.org/blog/Zhang-RQ/chang-ji-shuo-ji-ci-xian-xing-di-tui-chu-tan

  • 相关阅读:
    bzoj4598: [Sdoi2016]模式字符串
    bzoj3156: 防御准备
    bzoj1966: [Ahoi2005]VIRUS 病毒检测
    bzoj3170: [Tjoi2013]松鼠聚会
    bzoj3171: [Tjoi2013]循环格
    POJ1068Parencodings
    2013年山东省第四届ACM大学生程序设计竞赛 Alice and Bob
    POJ2632Crashing Robots
    POJ1328Radar Installation
    POJ2586Y2K Accounting Bug
  • 原文地址:https://www.cnblogs.com/lfri/p/11236711.html
Copyright © 2011-2022 走看看