今天练了不少快速幂的手,这一直是之前的一个漏洞吧,现在把洞补上。东西是挺简单的东西,当然题目多变,做起来也问题多多。
首先放一下核心代码:
// Matrix multiplication modulo `mod` (2x2 here).
const int mod = 10000;
const int maxn = 2;

// Square maxn x maxn matrix; mul() reduces its entries modulo `mod`.
struct matrix
{
    int a[maxn][maxn];
};

// Returns A * B with every entry reduced modulo `mod` (in [0, mod) for
// non-negative inputs).
// The i-k-j loop order lets a zero A.a[i][k] skip an entire inner loop, so
// sparse matrices cost noticeably less than the general O(maxn^3).
// Each product is formed in long long: an int product overflows as soon as the
// inputs are not already reduced (or if `mod` is raised above ~46341).
matrix mul(matrix A, matrix B)
{
    matrix ret;
    memset(ret.a, 0, sizeof(ret.a));
    for (int i = 0; i < maxn; ++i)
        for (int k = 0; k < maxn; ++k)
            if (A.a[i][k])
                for (int j = 0; j < maxn; ++j)
                {
                    long long cur = ret.a[i][j] + static_cast<long long>(A.a[i][k]) * B.a[k][j];
                    ret.a[i][j] = static_cast<int>(cur % mod);
                }
    return ret;
}
//快速幂计算。二分原理: a^k = (a^2)^(k/2) = ((a^2)^2)^(k/4); matrix expo(matrix p, int k) { if(k == 1) return p; matrix ret; memset(ret.a, 0, sizeof(ret.a)); for(int i = 0; i < maxn; ++i) ret.a[i][i] = 1; if(k == 0) return ret; while(k) { if(k & 1) ret = mul(p, ret); p = mul(p, p); k >>= 1; } return ret; }
对于此处代码,可以把指数写成二进制来理解:例如当 $k = 156$ 时,$156 = (10011100)_2 = 128 + 16 + 8 + 4$,于是 $ans = a^{156} = a^{128} \cdot a^{16} \cdot a^{8} \cdot a^{4}$。
从右向左扫描第 $i$ 位($i \ge 0$),该位对应的因子是 $a^{2^i}$(循环里的 p 每轮自乘一次,依次取到 $a^{1}, a^{2}, a^{4}, \dots$);碰见一个 1,就把当前的 $a^{2^i}$ 乘到 ans 里:
1 while(k) 2 { 3 if(k & 1) 4 ret = mul(p, ret); 5 p = mul(p, p); 6 k >>= 1; 7 }
结构体形式模板:
![](https://images.cnblogs.com/OutliningIndicators/ContractedBlock.gif)
#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
typedef long long ll;

// Runtime parameters of the template: n = dimension actually used (n <= maxn),
// mod = modulus applied to entries. k is kept as a global for callers.
int n, k, mod;
const int maxn = 100;

// Square matrix template; every operation works on the leading n x n block.
struct matrix
{
    int a[maxn][maxn];

    // Prints the leading n x n block modulo mod: entries separated by a space,
    // one row per line.
    void print()
    {
        for (int i = 0; i < n; i++)
        {
            for (int j = 0; j < n; j++)
            {
                if (j) printf(" ");
                printf("%d", a[i][j] % mod);
            }
            printf("\n");
        }
    }

    // Entry-wise addition modulo mod. The sum is formed in long long so it
    // cannot overflow int when mod is close to INT_MAX.
    matrix& operator += (const matrix& rhs)
    {
        for (int i = 0; i < n; ++i)
            for (int j = 0; j < n; ++j)
                if (rhs.a[i][j])
                {
                    ll sum = static_cast<ll>(a[i][j]) + rhs.a[i][j];
                    a[i][j] = static_cast<int>(sum >= mod ? sum % mod : sum);
                }
        return *this;
    }

    // *this = *this * rhs, modulo mod.
    // Products are computed in long long: since mod is a runtime value, an int
    // product overflows as soon as mod exceeds ~46341. The result is built in a
    // temporary, so `p *= p` (rhs aliasing *this) is safe.
    matrix& operator *= (const matrix& rhs)
    {
        matrix ret;
        memset(ret.a, 0, sizeof(ret.a));
        for (int i = 0; i < n; ++i)
            for (int t = 0; t < n; ++t)
                if (a[i][t])            // zero entries skip a whole inner loop
                    for (int j = 0; j < n; ++j)
                        ret.a[i][j] = static_cast<int>(
                            (ret.a[i][j] + static_cast<ll>(a[i][t]) * rhs.a[t][j]) % mod);
        memcpy(a, ret.a, sizeof(a));
        return *this;
    }
};

// Returns p^k (k >= 0) on the leading n x n block by binary exponentiation.
// k == 0 yields the identity; k == 1 returns p unchanged (entries not reduced).
matrix expo(matrix p, int k)
{
    if (k == 1) return p;
    matrix ret;
    memset(ret.a, 0, sizeof(ret.a));
    for (int i = 0; i < n; ++i)
        ret.a[i][i] = 1;
    while (k)
    {
        if (k & 1)
            ret *= p;
        p *= p;
        k >>= 1;
    }
    return ret;
}
练手:
1、经典入门题:
![](https://images.cnblogs.com/OutliningIndicators/ContractedBlock.gif)
/*
    Problem: Fibonacci
    Tips:    matrix fast power. For k >= 2,
                 | 1 1 |^(k-1)   | F(k)   F(k-1) |
                 | 1 0 |       = | F(k-1) F(k-2) |
             so F(k) mod 10000 is the top-left entry.
    Date:    2015.8.1
*/
#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
typedef long long ll;
const int mod = 10000;
const int maxn = 2;

struct matrix
{
    int a[maxn][maxn];
};

// Returns A * B modulo mod; a zero A.a[i][k] skips the whole inner loop.
matrix mul(matrix A, matrix B)
{
    matrix ret;
    memset(ret.a, 0, sizeof(ret.a));
    for (int i = 0; i < maxn; ++i)
        for (int k = 0; k < maxn; ++k)
            if (A.a[i][k])
                for (int j = 0; j < maxn; ++j)
                {
                    ret.a[i][j] += A.a[i][k] * B.a[k][j];
                    if (ret.a[i][j] >= mod)
                        ret.a[i][j] %= mod;
                }
    return ret;
}

// Returns p^k (k >= 0) by binary exponentiation; k == 1 returns p itself.
matrix expo(matrix p, int k)
{
    if (k == 1) return p;
    matrix ret;
    memset(ret.a, 0, sizeof(ret.a));
    for (int i = 0; i < maxn; ++i)
        ret.a[i][i] = 1;
    while (k)
    {
        if (k & 1)
            ret = mul(p, ret);
        p = mul(p, p);
        k >>= 1;
    }
    return ret;
}

// Reads k until EOF or -1 and prints F(k) mod 10000, one answer per line.
int main()
{
    int k;
    matrix m;
    while (~scanf("%d", &k))
    {
        if (k == -1) break;
        if (!k) printf("0\n");
        else if (k == 1) printf("1\n");
        else
        {
            m.a[0][0] = m.a[0][1] = m.a[1][0] = 1;
            m.a[1][1] = 0;
            matrix ans = expo(m, k - 1);
            printf("%d\n", ans.a[0][0] % mod);
        }
    }
    return 0;
}
2、学习构造矩阵:
![](https://images.cnblogs.com/OutliningIndicators/ContractedBlock.gif)
/*
    Problem: NYoj 301
    Tips:    matrix fast power, building the transition matrix
    Recurrence: f(x) = a*f(x-2) + b*f(x-1) + c

    Row state vector times transition matrix M (exactly what main() builds):
        | f(n-1) f(n-2) 1 |   | b 1 0 |
                            * | a 0 0 | = | f(n) f(n-1) 1 |
                              | c 0 1 |
    hence | f(2) f(1) 1 | * M^(n-2) = | f(n) f(n-1) 1 |.
    Date:    2015.8.1
*/
#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
typedef long long ll;
const int mod = 1000007;
const int maxn = 3;

struct matrix
{
    ll a[maxn][maxn];
};

// Maps any integer (including negatives of any magnitude) into [0, mod).
// A single "+= mod" is not enough once a value drops below -mod.
ll to_mod(ll x)
{
    x %= mod;
    return x < 0 ? x + mod : x;
}

// Returns A * B with all entries in [0, mod). Factors are normalised first, so
// every product is < mod^2 and the running sum cannot overflow long long.
matrix mul(matrix A, matrix B)
{
    matrix ret;
    memset(ret.a, 0, sizeof(ret.a));
    for (int i = 0; i < maxn; ++i)
        for (int k = 0; k < maxn; ++k)
            if (A.a[i][k])
                for (int j = 0; j < maxn; ++j)
                    ret.a[i][j] = to_mod(ret.a[i][j] + to_mod(A.a[i][k]) * to_mod(B.a[k][j]));
    return ret;
}

// Returns p^k (k >= 0) by binary exponentiation; k == 1 returns p itself.
matrix expo(matrix p, int k)
{
    if (k == 1) return p;
    matrix ret;
    memset(ret.a, 0, sizeof(ret.a));
    for (int i = 0; i < maxn; ++i)
        ret.a[i][i] = 1;
    while (k)
    {
        if (k & 1)
            ret = mul(p, ret);
        p = mul(p, p);
        k >>= 1;
    }
    return ret;
}

int main()
{
    ll f1, f2, a, b, c, n;
    matrix m1, m2;
    int T;
    scanf("%d", &T);
    while (T--)
    {
        scanf("%lld%lld%lld%lld%lld%lld", &f1, &f2, &a, &b, &c, &n);
        if (n == 1) printf("%lld\n", to_mod(f1));
        else if (n == 2) printf("%lld\n", to_mod(f2));
        else
        {
            memset(m1.a, 0, sizeof(m1.a));
            memset(m2.a, 0, sizeof(m2.a));
            // state row vector | f(2) f(1) 1 |, inputs reduced up front
            m1.a[0][0] = to_mod(f2), m1.a[0][1] = to_mod(f1), m1.a[0][2] = 1;
            // transition matrix | b 1 0 ; a 0 0 ; c 0 1 |
            m2.a[0][0] = to_mod(b), m2.a[1][0] = to_mod(a), m2.a[2][0] = to_mod(c);
            m2.a[0][1] = m2.a[2][2] = 1;
            matrix ans = mul(m1, expo(m2, n - 2));
            printf("%lld\n", ans.a[0][0]);
        }
    }
    return 0;
}
3、等比矩阵求和:构造矩阵(110MS)。$S = A + A^2 + A^3 + \dots + A^k = A(I + A(I + A(\cdots(I + A))))$;
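即构造分块矩阵 $\begin{pmatrix} A & I \\ 0 & I \end{pmatrix}$,每多乘一次,右上角就多累加一项 $A$ 的幂:它的 $k+1$ 次幂的右上角为 $I + A + \dots + A^k$,再减去 $I$ 即得 $S$。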
![](https://images.cnblogs.com/OutliningIndicators/ContractedBlock.gif)
/*
    Tips: matrix fast power, building a block matrix
    With the 2n x 2n block matrix
        B = | A I |
            | 0 I |
    one gets
        B^2 = | A^2 I+A |,   B^3 = | A^3 I+A+A^2 |,   ...
              | 0   I   |          | 0   I       |
        B^(k+1) = | A^(k+1) I+A+A^2+...+A^k |
                  | 0       I               |
    so the top-right block minus I is S = A + A^2 + ... + A^k.
    Date: 2015.8.1
*/
#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
typedef long long ll;
int mod;
const int maxn = 100;   // must hold the 2n x 2n block matrix

struct matrix
{
    int a[maxn][maxn];
};

// Entry-wise sum of the leading n x n blocks modulo mod (not called by main()).
matrix add(matrix A, matrix B, int n)
{
    matrix ret;
    memset(ret.a, 0, sizeof(ret.a));
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j)
            ret.a[i][j] = static_cast<int>((static_cast<ll>(A.a[i][j]) + B.a[i][j]) % mod);
    return ret;
}

// Returns A * B (leading n x n blocks) modulo mod; products in long long so
// unreduced input entries cannot overflow int.
matrix mul(matrix A, matrix B, int n)
{
    matrix ret;
    memset(ret.a, 0, sizeof(ret.a));
    for (int i = 0; i < n; ++i)
        for (int k = 0; k < n; ++k)
            if (A.a[i][k])
                for (int j = 0; j < n; ++j)
                    ret.a[i][j] = static_cast<int>(
                        (ret.a[i][j] + static_cast<ll>(A.a[i][k]) * B.a[k][j]) % mod);
    return ret;
}

// Returns p^k (k >= 0) on the leading n x n block; k == 1 returns p itself.
matrix expo(matrix p, int k, int n)
{
    if (k == 1) return p;
    matrix ret;
    memset(ret.a, 0, sizeof(ret.a));
    for (int i = 0; i < n; ++i)
        ret.a[i][i] = 1;
    while (k)
    {
        if (k & 1)
            ret = mul(p, ret, n);
        p = mul(p, p, n);
        k >>= 1;
    }
    return ret;
}

// Prints the leading n x n block modulo mod, one row per line.
void print(matrix ans, int n)
{
    for (int i = 0; i < n; i++)
    {
        for (int j = 0; j < n; j++)
        {
            if (j) printf(" ");
            printf("%d", ans.a[i][j] % mod);
        }
        printf("\n");
    }
}

int main()
{
    int n, k, m;
    matrix A, res;
    scanf("%d%d%d", &n, &k, &m);
    mod = m;
    // A is a local (indeterminate contents): zero it so the I / 0 / I blocks
    // around the input start clean before expo() reads them.
    memset(A.a, 0, sizeof(A.a));
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++)
            scanf("%d", &A.a[i][j]);
    if (k == 1)
        print(A, n);
    else
    {
        for (int i = 0; i < n; i++)
            A.a[i][i + n] = A.a[i + n][i + n] = 1;

        res = expo(A, k + 1, 2 * n);
        for (int i = 0; i < n; i++)
        {
            for (int j = 0; j < n; j++)
            {
                if (j) printf(" ");
                // top-right block holds I + S: drop the identity on the diagonal
                if (i != j) printf("%d", res.a[i][j + n] % mod);
                else printf("%d", (res.a[i][j + n] - 1 + mod) % mod);
            }
            printf("\n");
        }
    }
    return 0;
}
此题还有其他解法。记 $S_k = A + A^2 + \dots + A^k$,取 $h = \lfloor k/2 \rfloor$,可视
$S_k = (A + A^2 + \dots + A^{h}) + (A^{h+1} + A^{h+2} + \dots + A^{2h}) + [A^k]$
$= S_h + A^{h} S_h + [A^k]$
$= (I + A^{h})\, S_h + [A^k]$
其中方括号项 $[A^k]$ 仅在 $k$ 为奇数时才需要加上。
递归求解。
这种方法的效率不如上一种(每层递归都要重新做一次快速幂),自己没有再写,此处贴上别人的代码。
![](https://images.cnblogs.com/OutliningIndicators/ContractedBlock.gif)
// 800MS
// Recursive alternative: S(k) = (I + A^(k/2)) * S(k/2) [+ A^k when k is odd].
#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
int m, n, K;
int a[30][30];   // not used by this program

class Matrix
{
public:
    int num[30][30];

    // Builds the n x n identity when `is` is true (default), otherwise zero.
    Matrix(bool is = true)
    {
        memset(num, 0, sizeof(num));
        if (is)
            for (int i = 0; i < n; i++)
                num[i][i] = 1;
    }

    // Prints the n x n matrix, one row per line.
    void print()
    {
        for (int i = 0; i < n; ++i)
        {
            printf("%d", num[i][0]);
            for (int j = 1; j < n; ++j)
                printf(" %d", num[i][j]);
            printf("\n");
        }
    }

    // Matrix product modulo m. Returns BY VALUE: the original returned
    // `Matrix&` to the local `tmp`, a dangling reference (undefined behavior).
    friend Matrix operator *(const Matrix& max1, const Matrix& max2)
    {
        Matrix tmp(false);   // start from zero, not from the identity
        for (int i = 0; i < n; ++i)
            for (int j = 0; j < n; ++j)
            {
                long long s = 0;
                for (int k = 0; k < n; ++k)
                    s = (s + static_cast<long long>(max1.num[i][k]) * max2.num[k][j]) % m;
                tmp.num[i][j] = static_cast<int>(s);
            }
        return tmp;
    }

    // Entry-wise addition modulo m.
    Matrix& operator +=(const Matrix& max)
    {
        for (int i = 0; i < n; ++i)
            for (int j = 0; j < n; ++j)
                num[i][j] = (num[i][j] + max.num[i][j]) % m;
        return *this;
    }
} ans;

// Returns A^k by binary exponentiation (k == 1 returns A itself).
Matrix mul(Matrix A, int k)
{
    if (k == 1)
        return A;
    Matrix tmp;   // identity
    while (k)
    {
        if (k & 1)
            tmp = tmp * A;
        k >>= 1;
        A = A * A;
    }
    return tmp;
}

// Returns S[k] = A + A^2 + ... + A^k.
Matrix S(Matrix A, int k)
{
    if (k == 1)
        return A;

    Matrix tmp;                 // identity I
    tmp += mul(A, k >> 1);      // I + A^(k/2)
    tmp = tmp * S(A, k >> 1);   // (I + A^(k/2)) * S[k/2]
    if (k & 1)                  // odd k: the term A^k is still missing
        tmp += mul(A, k);
    return tmp;
}

int main()
{
    scanf("%d %d %d", &n, &K, &m);
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j)
        {
            scanf("%d", &ans.num[i][j]);
            ans.num[i][j] %= m;   // reduce up front: S(A, 1) returns A as-is
        }
    ans = S(ans, K);
    ans.print();
    return 0;
}
4、被这题坑大半天简直是想咬舌。
![](https://images.cnblogs.com/OutliningIndicators/ContractedBlock.gif)
/*
    Problem: UVa 10006
    Tips:    fast modular power practice
    Date:    2015.8.2
    Cause of the TLE: misreading the statement.
    Statement: n is a Carmichael number iff n is composite AND for EVERY
    (a universal quantifier, not "there exists") a the check (a^n) % n == a holds;
    the loop below tests every a in [2, n).
*/

#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
using namespace std;
typedef long long ll;
//const int mod = 1000007;
const int maxn = 65100;
int n;
bool pri[maxn];

// Sieve of Eratosthenes over [2, maxn). pri[0] and pri[1] are left true.
void get_pri()
{
    memset(pri, true, sizeof(pri));
    int m = sqrt(maxn + 0.5);
    for (int i = 2; i <= m; i++)
        if (pri[i])
            for (int j = i * i; j < maxn; j += i)
                pri[j] = false;
}

// Returns x^k mod `mod` (k >= 0) by binary exponentiation, using long long
// for the intermediate products.
int expo(int x, int k, int mod)
{
    if (k == 0) return 1 % mod;
    if (k == 1) return x % mod;
    ll ret = 1;
    while (k)
    {
        if (k & 1) ret = (ret * x) % mod;
        x = ((ll)x * x) % mod;
        k >>= 1;
    }
    return ret;
}

// Reads n until 0; prints one verdict per line.
int main()
{
    get_pri();
    while (~scanf("%d", &n) && n)
    {
        if (pri[n])
        {
            printf("%d is normal.\n", n);
            continue;
        }
        bool flag = true;
        for (int a = 2; a < n; a++)
            if (expo(a, n, n) != a)
            {
                flag = false;
                break;
            }

        if (flag) printf("The number %d is a Carmichael number.\n", n);
        else printf("%d is normal.\n", n);
    }
    return 0;
}