zoukankan      html  css  js  c++  java
  • 【学习笔记】 常系数齐次线性递推

    问题描述

    给定一个 (k) 阶常系数齐次线性递推数列的前 (k)(h_1, h_2, h_3..h_k) 和线性递推式 (h_n = sum_{i = 1} ^ k a_i h_{n - i}), 求这个数列的第 (n) 项。

    复杂度要求: (O(k^2logn))

    加强: (O(klogklogn))

    前置知识

    矩阵的特征值

    (A)(n) 阶方阵,如果存在数 (lambda) 和非零 (n) 维列向量 (x),使得 (Ax=lambda x) 成立,则称 (lambda) 是矩阵 (A) 的一个特征值。

    矩阵的特征多项式

    (E) 为单位矩阵,(n) 阶方阵 (A) 的特征多项式为 (|lambda E - A|)

    解释一下:(lambda) 为一个变量。 (lambda E - A) 的行列式是一个关于 (lambda)(n) 次多项式,即 (A) 的特征多项式。

    特征值是 (|lambda E - A| = 0) 的根。

    Cayley-Hamilton定理

    (A) 的特征多项式为 (p_A)(O) 为零矩阵。用矩阵 (A) 代替 (lambda) 带入 特征多项式,有 (p_A(A) = O)

    简要证明一下:把 (A) 直接带进特征多项式, (P_A(A) = |AE - A| = 0)

    其实我并不会证这个东西。自学的线代等于没学

    算法流程

    现在我们已经熟背了Cayley-Hamilton定理。

    考虑计算一下矩阵快速幂的转移矩阵 (M) 的特征多项式。

    (M_{x,y}) 为去掉位置 ((x,y)) 的代数余子式,(m_{x,y})((x,y)) 位置上的元素,直接对第 (1) 行拉普拉斯展开,(M = sum_{i=1}^k m_{1,i} M_{1,i}) ,可以发现去掉第一行第 (i) 列之后留下的是一个下三角矩阵,(M_{1,i} = (-1) ^ {i + 1} (-1)^{i-1}lambda^{k - i}),(m_{1,1} =lambda-a_1,m_{1,i} = -a_i),整理一下就得到:

    [|Elambda - M| = lambda ^ k - a_1 lambda ^ {k - 1} - a_2 lambda ^ {k - 2} - ... - a_n ]

    由 Cayley-Hamilton定理,我们得到 (p_M(M) = O)

    现在我们用 (n) 代替 (n - k),我们要求出 (M^n)

    (M^n= p_M(M) A(M) + r(M)),其中(r(M)) 的次数不高于(k - 1)

    因为 (p_M(M) = O), 得到 (M^n = r(M))

    我们现在只需要求 (M^n mod p_M(M))

    因为 (AB mod C = (A mod C)(B mod C) mod C),所以可以快速幂计算 (M^n mod p_M(M))

    直接暴力复杂度 (O(k^2logn)), 用多项式乘法和取模可以做到 (O(klogklogn))

    现在得到了 (M^n = r(M) = sum_{i = 0} ^ {k - 1} c_i M ^ i)

    再分析一下就可以得到 (ans = sum_{i = 0} ^ {k - 1} c_i h_{i + k})

    再暴力或者 NTT 处理一下前 (h) 的前 (2k) 项就好了。

    需要注意,当 (k = 1) 时,(M^1) 需要对 (p_M(M)) 取模。

    模板

    BZOJ4161: Shlw loves matrixI

    这道题里面 (h) 的下标是从 (0) 开始的。

    #pragma GCC optimize("2,Ofast,inline")
    #include<bits/stdc++.h>
    #define fi first
    #define se second
    #define mp make_pair
    #define pb push_back
    #define LL long long
    #define pii pair<int, int>
    using namespace std;
    const int mod = 1e9 + 7;
    
    template <typename T> T read(T &x) {
    	int f = 0;
    	register char c = getchar();
    	while (c > '9' || c < '0') f |= (c == '-'), c = getchar();
    	for (x = 0; c >= '0' && c <= '9'; c = getchar())
    		x = (x << 3) + (x << 1) + (c ^ 48);
    	if (f) x = -x;
    	return x;
    }
    
    inline void upd(int &x, int y) {
    	(x += y) >= mod ? x -= mod : 0;
    }
    
    inline int add(int x, int y) {
    	return (x += y) >= mod ? x - mod : x;
    }
    
    inline int dec(int x, int y) {
    	return (x -= y) < 0 ? x + mod : x;
    }
    
    namespace Linear {
    	static const int Maxn = 5005;
        
    	int n, k;
    	int a[Maxn], h[Maxn];
    	int b[Maxn], c[Maxn], p[Maxn];
    
    	void module(int *x) {
    		for (int i = k * 2; i >= k; --i) {
    			int tmp = x[i];
    			for (int j = 0; j <= k; ++j) {
    				x[i - j] = dec(x[i - j], 1LL * p[k - j] * tmp % mod);
    			}
    		}
    	}
    	
    	void mul(int *x, int *y, int *z) {
    		static int res[Maxn];
    		for (int i = 0; i <= k * 2; ++i) res[i] = 0;
    		for (int i = 0; i < k; ++i) {
    			for (int j = 0; j < k; ++j) {
    				upd(res[i + j], 1LL * x[i] * y[j] % mod);       
    			}
    		}
    		module(res);
    		for (int i = 0; i < k; ++i) z[i] = res[i];
    	}
    
    	void poly_pow(int p) {
    		while (p) {
    			if (p & 1) mul(b, c, c);
    			p >>= 1;
    			if (!p) break;
    			mul(b, b, b);
    		}
    	}
        
    	int solve() {
    		if (n <= k) return h[n];
    		p[k] = 1;
    		for (int i = 0; i < k; ++i)
    			p[i] = dec(0, a[k - i]);
    		b[1] = 1; c[0] = 1;
    		if (k == 1) module(b);
    		poly_pow(n - k);
    		for (int i = k + 1; i <= k * 2; ++i) {
    			for (int j = 1; j <= k; ++j) {
    				upd(h[i], 1LL * h[i - j] * a[j] % mod);
    			}
    		}
    		int ans = 0;
    		for (int i = 0; i < k; ++i)
    			upd(ans, 1LL * c[i] * h[i + k] % mod);
    		return ans;
    	}
    }
    using namespace Linear;
    
    int main() {
    	read(n); read(k); ++n;
    	for (int i = 1; i <= k; ++i) {
    		read(a[i]);
    		if (a[i] < 0) a[i] += mod;
    	}
    	for (int i = 1; i <= k; ++i) {
    		read(h[i]);
    		if (h[i] < 0) h[i] += mod;
    	}
    	cout << solve() << endl;
    	return 0;
    }
    
    
  • 相关阅读:
    Mysql备份和恢复
    前端Css学习
    jQuery学习
    HTML页面学习
    Linux下java环境变量配置
    windows下java环境变量标准配置
    oracle查询消耗服务器资源SQL语句
    Java主线程在子线程执行完毕后再执行
    CentOS7 安装 Redis
    查看Oracle表空间使用情况
  • 原文地址:https://www.cnblogs.com/Vexoben/p/11841443.html
Copyright © 2011-2022 走看看