原文链接www.cnblogs.com/zhouzhendong/p/Cayley-Hamilton.html
Cayley-Hamilton定理与矩阵快速幂优化、常系数齐次线性递推优化
引入
在开始本文之前,我们先用一个例题作为引入。
- 给定一个 (n imes n) 的矩阵 (M) , 求 (M ^ k) 。
- (nleq 50, kleq 10 ^ {50000}) 。
注意到 (n) 十分小,但是 $ log k$ 非常大。如果使用传统的矩阵快速幂,时间复杂度为 (O(n ^ 3 log k )) ,难以接受。
但是运用 Cayley-Hamilton定理 来优化矩阵快速幂,可以做到 (O(n ^ 4+n^2log k)) 甚至更优秀的复杂度(^1)。
- cly 说他看到有人说可以优化到 (O(n^3 + n^2 log k)) 。但是不知道怎么优化。
Cayley-Hamilton定理
设 (M) 为一个 (n) 阶矩阵,定义矩阵 (M) 的特征多项式为
其中 (x) 可以属于一些域,包括但不限于复数域。
由于 (|ME - M| = |0| = 0) ,所以
矩阵快速幂
如果我们得到了 (f(M)) ,那么,将任意一个矩阵减去任意倍数的 (f(M)) 后值不变。
即:令 (g(x) = x^k),
考虑如何求解 (g mod f)。
类比对整数域下取模的快速幂做法,考虑对多项式 (x) 做快速幂,对 (f) 取模,直接实现的时间复杂度为 (O(n ^ 2 log k)),如果采用 (FFT) 优化乘法,并利用多项式取模的做法实现取模,那么时间复杂度为 (O(n log n log k)) 。
接下来我们来讨论如何求 (f(M)) 。
将 (n) 个值代入 (x) 求行列式,再插值得到 (f(M)) ,时间复杂度 $O(n ^ 4) $ ,注意这里要求代入的值存在乘法逆元。
得到 (g mod f) 之后只需要将所有 (M) 的幂代入即可得到 $M ^ k $,这里时间复杂度也是 $O(n ^ 4) $ 。
综上所述,总时间复杂度 (O(n ^ 4)) 。
线性递推
回归本源,当我们要做矩阵快速幂的原因,往往是为了快速实现线性递推。由于线性递推问题存在特殊性,我们可以通过 Cayley-Hamilton定理 来得到更优秀的做法。
假设线性递推数列满足
我们将线性递推的矩阵写出来:
假设初始向量是列向量 (B) ,矩阵是 (M) ,那么
则我们要求的是
接下来我们来求 (M) 的特征多项式。
于是我们就直接得到了特征多项式的系数。
于是
假设结果为
因为 (B M ^ i [a] = B[a + i]),而这里 (i < k),我们要求的是 (BM^i[0]) ,所以我们只需要知道 (B[0] cdots B[k-1]) 即可。类似地,只要我们预处理 B 数列的前 (2k) 项,就可以得到整个 (BM^i) 列向量了。
求单个值的时间复杂度为 (O(k ^ 2log n)) 。
模板题 BZOJ4161
代码
#include <bits/stdc++.h>
#define clr(x) memset(x,0,sizeof x)
#define For(i,a,b) for (int i=a;i<=b;i++)
#define Fod(i,b,a) for (int i=b;i>=a;i--)
#define fi first
#define se second
#define pb(x) push_back(x)
#define mp(x,y) make_pair(x,y)
#define outval(x) printf(#x" = %d
",x)
#define outtag(x) puts("---------------"#x"---------------")
#define outarr(a,L,R) printf(#a"[%d..%d] = ",L,R);
For(_x,L,R)printf("%d ",a[_x]);puts("")
using namespace std;
typedef long long LL;
LL read(){
LL x=0,f=0;
char ch=getchar();
while (!isdigit(ch))
f|=ch=='-',ch=getchar();
while (isdigit(ch))
x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
return f?-x:x;
}
const int N=2005*2,mod=1e9+7;
int n,k;
int a[N],b[N];
void Add(int &x,int y){
if ((x+=y)>=mod)
x-=mod;
}
void Del(int &x,int y){
if ((x-=y)<0)
x+=mod;
}
int c[N];
void Mul(int *x,int *y){
static int z[N];
clr(z);
For(i,0,k-1)
For(j,0,k-1)
Add(z[i+j],(LL)x[i]*y[j]%mod);
Fod(i,2*k-2,k){
if (!z[i])
continue;
For(j,1,k)
Add(z[i-j],(LL)a[j]*z[i]%mod);
}
For(i,0,k-1)
x[i]=z[i];
}
void GetPoly(){
static int x[N];
int y=n;
clr(x),clr(c),c[0]=x[1]=1;
for (;y;y>>=1,Mul(x,x))
if (y&1)
Mul(c,x);
}
int main(){
n=read(),k=read();
For(i,1,k)
a[i]=(read()+mod)%mod;
For(i,0,k-1)
b[i]=(read()+mod)%mod;
GetPoly();
int ans=0;
For(i,0,k-1)
Add(ans,(LL)b[i]*c[i]%mod);
cout<<ans<<endl;
return 0;
}