原题链接
- 题意:给出矩阵 (A),为 (N imes M) 的矩阵,矩阵 (B) 为 (M imes N) 的矩阵,(4 <= N <= 1000, 1 <= M <= 10) ,设矩阵 (C = A imes B) 求出 (C^n) 各个元素和。
- 题解:可以发现的是乘一步 (O(n^3)) 显然吃不消。但是,可以发现一个性质,即 (M) 比较小。((A imes B) imes(A imes B) = A imes(B imes A) imes B) 可以发现,如果 (n) 个 (A imes B) 相乘,可以转化成 (n-1) 个 ((B imes A)) 相乘,设结果为矩阵 (T),那么最终结果即为 (A imes T imes B), 而求 (B imes A) 是 (N imes M) 和 (M imes N) 是 (O(m^2n)) 的复杂度,可以接受。算 (T) 的复杂度是可用快速幂,即 (O(m^2n imes log n^2)),然后最终 (A imes T) 是 (O(m^2n)) 的复杂度,然后再乘 (B) 是 (O(n^2m)) 的复杂度。
最终复杂的度是 (O(m^2n imes log n^2 + n^2m))。
- 代码:
#include <iostream>
#include <cstring>
using namespace std;
typedef long long ll;
const int N = 1e3 + 9;
const ll mod = 6;
int n, m;
ll A[N][N], B[N][N], T[N][N],ans[N][N];
struct Matrix {
ll a[10][10];
Matrix(){memset(a, 0, sizeof a);}
Matrix operator*(Matrix rhs)const {
Matrix ret;
for (int i = 1; i <= m; i ++) {
for (int j = 1; j <= m; j++){
for (int k = 1; k <= m; k++) {
(ret.a[i][j] += (a[i][k] * rhs.a[k][j] % mod)) %= mod;
}
}
}
return ret;
}
void pr() {
for (int i = 1; i <= m; i++) {
for (int j = 1; j <= m; j++) {
cout << a[i][j] << " ";
}
cout << endl;
}
}
};
Matrix ksm (Matrix A, int kk) {
if (kk == 1)return A;
Matrix ret;
bool f = 0;
//cout << kk << "???";
while (kk) {
if (kk & 1) {
if (!f) {
ret = A;
f = 1;
// cout << "?";
} else
ret = ret * A;
}
kk >>= 1;
A = A * A;
}
return ret;
}
void solve() {
while (cin >> n >> m) {
if (n == 0 && m == 0)return;
Matrix C;
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= m; j++) {
cin >> A[i][j];
}
}
for (int i = 1; i <= m; i++) {
for (int j = 1; j <= n; j++) {
cin >> B[i][j];
}
}
for (int i = 1; i <= m; i ++) {
for (int j = 1; j <= m; j ++) {
for (int k = 1; k <= n; k++) {
(C.a[i][j] += B[i][k] * A[k][j] % mod)%=mod;
}
}
}
int kk = n * n-1;
Matrix M = ksm(C, kk);
memset(T, 0, sizeof T);
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= m; j++) {
for (int k = 1; k <= m; k++) {
(T[i][j] += A[i][k] * M.a[k][j] % mod) %= mod;
}
}
}
ll sum = 0;
memset(ans, 0, sizeof ans);
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= n; j++) {
for (int k = 1; k <= m; k++) {
(ans[i][j] += T[i][k] * B[k][j] % mod) %= mod;
}
}
}
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= n; j++) {
sum += ans[i][j];
}
}
cout << sum << endl;
}
}
signed main() {
int t = 1;//cin >> t;
while (t--) {
solve();
}
}