zoukankan      html  css  js  c++  java
  • 常系数齐次线性递推

    常系数齐次线性递推

    定义

    对于一个递推式,如果 (a_n = displaystyle sum_{i=1}^{k}{a_{n-i}*f_i}) ,那么称这个 (a) 序列满足 (n) 阶常系数齐次线性递推关系

    矩阵优化

    如果我们已知一个满足 (k) 阶常系数齐次线性递推关系的序列 (a) ,关系式为 (a_n = displaystyle sum_{i=1}^{k}{a_{n-i} * f_i}) ,要求出 (a_n) 的值

    可以设计出一个转移矩阵进行矩阵优化

    如果初始阵为

    [A= egin{pmatrix} a_{n-1}\ a_{n-2}\ vdots\ a_{n-k} end{pmatrix} ]

    转移阵为

    [M= egin{pmatrix} f_1 quad &f_2 quad &f_3 quad &dots quad &f_{k-1}\ 1 quad &0 quad &0 quad &dots quad &0\ 0 quad &1 quad &0 quad &dots quad &0\ vdots quad &vdots quad &vdots quad &ddots quad &vdots \ 0 quad &0 quad &0 quad &dots &1 end{pmatrix} ]

    那么 (M imes A) 可以得到矩阵

    [egin{pmatrix} a_{n}\ a_{n-1}\ vdots\ a_{n-k-1} end{pmatrix} ]

    那么我们可以设计初始矩阵为

    [A= egin{pmatrix} a_{k-1}\ a_{k-2}\ vdots\ a_{0} end{pmatrix} ]

    此时我们可以用 (M^n imes A) 来得到我们需要的矩阵

    特征多项式

    • 若有常数 (lambda) ,向量 (vec{v}) ,满足 (lambda vec{v} = A vec{v}) ,那么我们称 (lambda) 为矩阵 (A) 的特征值,称 (vec{v}) 为矩阵的特征向量

    那么我们可以得到 ((lambda I - A) vec{v}= 0) ,其中 (0) 表示零矩阵

    此时该式有解当且仅当 (det(lambda I - A) = 0)

    这个行列式的展开形式为一个 (k) 次多项式,此时,我们称这个 (k) 次多项式为 (A) 的特征多项式,该多项式的值为 (0) 时的方程称为 (A) 的特征方程

    记特征多项式为 (f(x) = det(lambda I - A)) ,那么可以表示为 (f(x) = displaystyle prod_{i}{lambda_i - x})

    凯莱-哈密顿定理 (Cayley-Hamilton定理)

    • 对于 (A) 的特征多项式 (f(x)) ,有 (f(A) = 0)

    证明

    (f(A) =displaystyle prod_{i}{lambda_i I - A})

    对于这个 (k) 次的特征多项式,其有 (k) 个解,也就是说矩阵 (A)(k) 个特征值以及 (k) 个线性无关的特征向量,而如果 (f(A)) 得到的矩阵乘上任意一个特征向量都可以得到零矩阵,那么就可以推出 (f(A)) 为零矩阵

    首先,可以证明, ((lambda_i I - A)(lambda_j I - A) = (lambda_j I - A)(lambda_i I - A))

    那么

    [egin{aligned} f(A) imes vec{v_i} &= (displaystyle prod_{j}{lambda_j I - A}) imes vec{v_i} \ &= (displaystyle prod_{j eq i}{lambda_j I - A}) imes ((lambda_i I - A) imes vec{v_i}) end{aligned} ]

    由特征值与特征向量的定义式可知: ((lambda_i I - A) vec{v_i} = 0)

    所以 (forall f(A) imes vec{v_i} =0)

    得证

    常系数齐次线性递推优化

    设矩阵 (M) 的特征多项式为 (f(x))

    对于我们要求的 (M^n) ,可以写出

    [M^n = f(M) imes g(M) + R(M) ]

    (f(M)=0) ,那么就有 (M^n = R(M))

    所以,我们只需要做 (M^n ~\% ~f(M)) 就可以了

    考虑 (f(M)) 怎么求

    按照定义 (f(x) = det(x I - M)) ,所以这里有

    [f(x)= egin{vmatrix} x-a_1 quad &-a_2 quad &-a_3 quad &dots quad &-a_{k-1}quad &-a_{k}\ -1 quad &x quad &0 quad &dots quad &0 quad &0\ 0 quad &-1 quad &x quad &dots quad &0 quad &0\ vdots quad &vdots quad &vdots quad &ddots quad &vdots quad &vdots\ 0 quad &0 quad &0 quad &dots &-1 quad &x end{vmatrix} ]

    将其进行展开,有

    [egin{aligned} f(x) &= displaystyle (x-a_1)M_{11} - a_2 M_{12} dotsb - a_k M_{1k}\ &= x^k - a_1 x^{k-1} - a_2 x^{k-2} - dotsb a_k end{aligned} ]

    处理 (M^n ~\%~ f(M)) 我们可以在做快速幂的时候进行实现,所以这里的实现只需要在快速幂的时候做多项式取模即可,复杂度为 (O(k^2 log n))

    而这里我们做快速幂的时候还会涉及多项式乘法,那么可以进行 NTTFFT 优化,做到 (O(k log k log n))

    那么我们这里已经快速处理出了 (M^n) ,之后直接和初始的矩阵 (A) 相乘即可求得答案

    例题

    首先,恰好 (K) 个的概率不容易处理,可以考虑将其处理为至少有 (K) 个的概率减去至少有 (K-1) 个的概率

    (f_i) 表示在底部的一个宽为 (i) 的矩形,并且第 (i) 个位置恰好为不合法的位置

    那么最终答案就是 (frac{f_{n+1}}{1-q})

    这里有 (f_n = displaystyle sum_{i=1}^{n}{f_{n-i+1} * g_i}) ,这里 (g_i) 表示出现长度宽为 (i) 的矩形的概率

    (dp_{i,j}) 表示一个宽为 (i) ,高位 (j) 的矩形

    那么这里 (g_i = displaystyle sum_{j=1}^{infty}{dp_{i,j}})

    而这个 (dp_{i,j}) 实际上也是可以递推的,有递推式为

    [dp_{i,j} = [i*(j-1) leq K] (1-q) q^{j-1} displaystyle sum_{k=1}^{i}{(displaystyle sum_{q > j} dp_{k-1,q})(displaystyle sum_{q geq j}{dp_{i-k,q}})} ]

    表示 (dp_{i,j}) 可以由宽 (k-1) 中那些高度大于 (j) 的矩形的情况在和宽 (i-k) ,高大于等于 (j) 的那些矩形拼起来,再乘上当前宽度为 (i) 的这个地方的高度只有 (j) 的部分的概率

    这样,这里 (i imes (j-1) leq K) ,所以对 (i,j) 的枚举的复杂度为 (O(K log K)) ,再加上枚举 (k,q) 的枚举,复杂度为 (O(K^2 log^2 K))

    而这个式子中对 (q) 枚举的部分是可以后缀和优化的(并且在 (f) 的求解中应用),那么此时求 (dp) 数组的复杂度可以被优化到 (O(k^2 log k))

    (f) 数组的求解显然满足常系数其次线性递推的形式,可以直接套用优化,那么总复杂度为 (O(k^2 log k)) (完全没有必要用 FFTNTT 优化,直接暴力做多项式取模即可)

    #include<iostream>
    #include<cstdio>
    #include<algorithm>
    #include<math.h>
    #include<vector>
    #include<queue>
    #include<cstring>
    #define ll long long
    #define ld long double
    
    inline ll read()
    {
    	ll x=0,f=1;
    	char ch=getchar();
    	while(!isdigit(ch))
    	{
    		if(ch=='-') f=-1;
    		ch=getchar();
    	}
    	while(isdigit(ch))
    	{
    		x=(x<<1)+(x<<3)+ch-'0';
    		ch=getchar();
    	}
    	return x*f;
    }
    
    const ll inf=1e18;
    const ll maxn=2e3+10;
    const ll mod=998244353;
    ll N,K,X,Y,p,q;
    ll pw[maxn];
    ll dp[maxn][maxn],sum[maxn][maxn];
    ll I[maxn],A[maxn],M[maxn],ret[maxn],f[maxn];
    ll tmp1[maxn],tmp2[maxn];
    
    inline ll ksm(ll a,ll b,ll p)
    {
    	ll ret=1;
    	while(b)
    	{
    		if(b&1) ret=ret*a%p;
    		a=a*a%p;
    		b>>=1;
    	}
    	return ret;
    }
    
    inline ll sol(ll x)
    {
    	ll ans=0;
    	memset(M,0,sizeof(M));
    	memset(A,0,sizeof(A));
    	memset(f,0,sizeof(f));
    	memset(I,0,sizeof(I));
    	memset(dp,0,sizeof(dp));
    	memset(sum,0,sizeof(sum));
    	memset(ret,0,sizeof(ret));
    	for(int i=0;i<=x+2;i++) sum[0][i]=dp[0][i]=1;
    	for(int j=x;j>=1;j--)
    	{
    		for(int i=1;i*j<=x;i++)
    		{
    			for(int k=1;k<=i;k++)
    			{
    				(dp[i][j]+=sum[k-1][j+1]*sum[i-k][j]%mod*p%mod*pw[j]%mod)%=mod;
    			}
    			sum[i][j]=(sum[i][j+1]+dp[i][j])%mod;
    		}
    	}
    //	for(int j=1;j<=x;j++)
    //	{
    //		for(int i=1;i*j<=x;i++)
    //		{
    //			printf("%d %d %lld %lld
    ",i,j,dp[i][j],sum[i][j]);
    //		}
    //	}
    	x++;
    	for(int i=1;i<=x;i++) I[i]=sum[i-1][1]*p%mod;
    	A[0]=1;
    	for(int i=1;i<=x;i++)
    	{
    		for(int j=0;j<i;j++)
    		{
    			(A[i]+=A[j]*I[i-j]%mod)%=mod;
    		}
    	}
    	for(int i=1;i<=x;i++) f[x-i]=mod-I[i];
    	f[x]=1;
    //	for(int i=0;i<=x;i++) printf("%lld ",I[i]);
    //	putchar(10);
    //	for(int i=0;i<=x;i++) printf("%lld ",A[i]);
    //	putchar(10);
    //	for(int i=0;i<=x;i++) printf("%lld ",f[i]);
    //	putchar(10);
    	ret[0]=1;
    	M[1]=1;
    	ll b=N+1;
    	while(b)
    	{
    		if(b&1)
    		{
    			memcpy(tmp1,ret,sizeof(ret));
    			memset(ret,0,sizeof(ret));
    			for(int i=0;i<=x;i++)
    			{
    				for(int j=0;j<=x;j++)
    				{
    					(ret[i+j]+=M[i]*tmp1[j])%=mod;
    				}
    			}
    			for(int i=2*x;i>=x;i--)
    			{
    				for(int j=0;j<=x;j++)
    				{
    					(ret[i+j-x]+=mod-ret[i]*f[j]%mod)%=mod;
    				}
    			}
    		}
    		memcpy(tmp1,M,sizeof(M));
    		memcpy(tmp2,M,sizeof(M));
    		memset(M,0,sizeof(M));
    		for(int i=0;i<=x;i++)
    		{
    			for(int j=0;j<=x;j++)
    			{
    				(M[i+j]+=tmp1[i]*tmp2[j]%mod)%=mod;
    			}
    		}
    		for(int i=2*x;i>=x;i--)
    		{
    			for(int j=0;j<=x;j++)
    			{
    				(M[i+j-x]+=mod-M[i]*f[j])%=mod;
    			}
    		}
    		b>>=1;
    	}
    	for(int i=0;i<=x;i++) (ans+=ret[i]*A[i])%=mod;
    //	printf("%lld
    ",ans);
    	return ans*ksm(p,mod-2,mod)%mod;
    }
    
    int main(void)
    {
    //	freopen("1.in","r",stdin);
    //	freopen("1.ans","w",stdout);
    	N=read(),K=read(),X=read(),Y=read();
    	q=X*ksm(Y,mod-2,mod)%mod;
    	p=(1-q+mod)%mod;
    	pw[0]=1;
    	for(int i=1;i<=K;i++) pw[i]=pw[i-1]*q%mod;
    	printf("%lld
    ",(sol(K)-sol(K-1)+mod)%mod);
    	return 0;
    }
    
  • 相关阅读:
    智能汽车无人驾驶资料调研(一)
    Python 学习
    关于中英文排版的学习
    UI Testing
    项目管理:第一次参与项目管理
    自动化测试用什么语言好
    什么是自动化测试
    睡眠的重要性
    python的pip和cmd常用命令
    矩阵的切片计算(截取)
  • 原文地址:https://www.cnblogs.com/jd1412/p/15220789.html
Copyright © 2011-2022 走看看