简单的方阵乘法
SQUARE-MATRIX-MULTIPLY(A,B)
1  n = A.rows
2  let C be a new n×n matrix
3  for i = 1 to n
4      for j = 1 to n
5          $c_{ij} = 0$
6          for k = 1 to n
7              $c_{ij} = c_{ij} + a_{ik} \cdot b_{kj}$
8  return C

一个简单的分治算法（下面两个递归算法都假设 n 是 2 的幂，这样每一层都能把矩阵均分成 4 个 n/2×n/2 的子矩阵）
SQUARE-MATRIX-MULTIPLY-RECURSIVE(A,B)
1  n = A.rows
2  let C be a new n×n matrix
3  if n == 1
4      $c_{11} = a_{11} \cdot b_{11}$
5  else partition A, B, and C as in equation (4.9)
6      C11 = SQUARE-MATRIX-MULTIPLY-RECURSIVE(A11,B11) + SQUARE-MATRIX-MULTIPLY-RECURSIVE(A12,B21)
7      C12 = SQUARE-MATRIX-MULTIPLY-RECURSIVE(A11,B12) + SQUARE-MATRIX-MULTIPLY-RECURSIVE(A12,B22)
8      C21 = SQUARE-MATRIX-MULTIPLY-RECURSIVE(A21,B11) + SQUARE-MATRIX-MULTIPLY-RECURSIVE(A22,B21)
9      C22 = SQUARE-MATRIX-MULTIPLY-RECURSIVE(A21,B12) + SQUARE-MATRIX-MULTIPLY-RECURSIVE(A22,B22)
10 return C

矩阵乘法的Strassen算法
SQUARE-MATRIX-STRASSEN-RECURSIVE(A,B)
1  n = A.rows
2  let C be a new n×n matrix
3  if n == 1
4      $c_{11} = a_{11} \cdot b_{11}$
5  else partition A, B, and C as in equation (4.9)
6      S1 = B12 - B22
7      S2 = A11 + A12
8      S3 = A21 + A22
9      S4 = B21 - B11
10     S5 = A11 + A22
11     S6 = B11 + B22
12     S7 = A12 - A22
13     S8 = B21 + B22
14     S9 = A11 - A21
15     S10 = B11 + B12
16     P1 = SQUARE-MATRIX-STRASSEN-RECURSIVE(A11,S1)
17     P2 = SQUARE-MATRIX-STRASSEN-RECURSIVE(S2,B22)
18     P3 = SQUARE-MATRIX-STRASSEN-RECURSIVE(S3,B11)
19     P4 = SQUARE-MATRIX-STRASSEN-RECURSIVE(A22,S4)
20     P5 = SQUARE-MATRIX-STRASSEN-RECURSIVE(S5,S6)
21     P6 = SQUARE-MATRIX-STRASSEN-RECURSIVE(S7,S8)
22     P7 = SQUARE-MATRIX-STRASSEN-RECURSIVE(S9,S10)
23     C11 = P5 + P4 - P2 + P6
24     C12 = P1 + P2
25     C21 = P3 + P4
26     C22 = P5 + P1 - P3 - P7
27 return C
/*C++代码。书上给的分解矩阵的做法是用角标计算而不用建立新的对象,不过我并没有想到可以不用新建对象而进行递归的办法,所以这里还是和书上有些不一样的。另外因为演示,所以新建了个类,不过这个类并不稳定,仅作测试时了解功能就好*/
Matrix.h class SquareMatrix { public: SquareMatrix(); SquareMatrix(int **data,int rows); SquareMatrix(int rows); ~SquareMatrix(); int CreateSqMa(int rows); int SetData(int rows,int *data); int **iData; int iRows; friend SquareMatrix operator+(SquareMatrix A,SquareMatrix B); friend SquareMatrix operator-(SquareMatrix A,SquareMatrix B); int SprintSqMa(); }; Matrix.cpp #include <iostream> #include "Matrix.h" SquareMatrix::SquareMatrix() { } SquareMatrix::SquareMatrix(int rows) { this->CreateSqMa(rows); } SquareMatrix::SquareMatrix(int **data,int rows) { iData=data; iRows=rows; } SquareMatrix::~SquareMatrix() { } int SquareMatrix::CreateSqMa(int rows) { iRows = rows; iData =new int *[rows]; for (int i=0;i<iRows;i++) { iData[i]=new int [rows] ; for (int j=0;j<iRows;j++) { iData[i][j]=0; } } return 0; } int SquareMatrix::SetData(int rows,int *data) { int length=rows; for (int i = 0; i < length; i++) { for (int j = 0; j < length; j++) { iData[j][i]=data[i*rows+j]; } } iRows=rows; return 0; } SquareMatrix operator+(SquareMatrix A,SquareMatrix B) { SquareMatrix C(A.iRows); for(int i=0;i<B.iRows;i++) { for(int j=0;j<B.iRows;j++) { C.iData[i][j]=A.iData[i][j]+B.iData[i][j]; } } C.iRows=A.iRows; return C; } SquareMatrix operator-(SquareMatrix A,SquareMatrix B) { SquareMatrix C(A.iRows); for(int i=0;i<B.iRows;i++) { for(int j=0;j<B.iRows;j++) { C.iData[i][j]=A.iData[i][j]-B.iData[i][j]; } } return C; } int SquareMatrix::SprintSqMa() { for(int i=0;i<iRows;i++) { for(int j=0;j<iRows;j++) { std::cout<<iData[i][j]<<' '; if(j==(iRows-1)) { std::cout<<std::endl; } } } return 0; } MAIN.cpp #include <iostream> #include "Matrix.h" using namespace std; SquareMatrix SquareMatrixMultiply(SquareMatrix A,SquareMatrix B); SquareMatrix SquareMatrixMultiplyRecursive(SquareMatrix A,SquareMatrix B); SquareMatrix Strassen(SquareMatrix A,SquareMatrix B); int main() { SquareMatrix A,B,C; B.CreateSqMa(4); C.CreateSqMa(4); A.CreateSqMa(4); int arr[16]={1,2,1,1,1,1,1,1,1,1,1,1,1,1,1,1}; 
B.SetData(4,arr); A.SetData(4,arr); //C=A+B; //C=SquareMatrixMultiply(A,B); C=SquareMatrixMultiplyRecursive(A,B); //C=Strassen(A,B); cout<<"算法导论4.2矩阵乘法Strassen算法"<<endl; A.SprintSqMa(); cout<<endl; B.SprintSqMa(); cout<<endl; C.SprintSqMa(); cout<<endl; system("pause"); return 0; } SquareMatrix SquareMatrixMultiply(SquareMatrix A,SquareMatrix B) { SquareMatrix C(A.iRows); int n=A.iRows; for (int i=0;i<n;i++) { for(int j=0;j<n;j++) { for(int k=0;k<n;k++) { C.iData[i][j]=C.iData[i][j]+A.iData[i][k]*B.iData[k][j]; } } } return C; } SquareMatrix SquareMatrixMultiplyRecursive(SquareMatrix A,SquareMatrix B) { SquareMatrix C(A.iRows); int n=A.iRows; if(n==1) { C.iData[0][0]=A.iData[0][0]*B.iData[0][0]; } else { int rows_n=n/2; SquareMatrix A11(rows_n),A12(rows_n), A21(rows_n),A22(rows_n), B11(rows_n),B12(rows_n), B21(rows_n),B22(rows_n), C11(rows_n),C12(rows_n), C21(rows_n),C22(rows_n); for (int i=0;i<rows_n;i++) { for(int j=0;j<rows_n;j++) { A11.iData[i][j]=A.iData[i][j]; A12.iData[i][j]=A.iData[i][j+rows_n]; A21.iData[i][j]=A.iData[i+rows_n][j]; A22.iData[i][j]=A.iData[i+rows_n][j+rows_n]; B11.iData[i][j]=B.iData[i][j]; B12.iData[i][j]=B.iData[i][j+rows_n]; B21.iData[i][j]=B.iData[i+rows_n][j]; B22.iData[i][j]=B.iData[i+rows_n][j+rows_n]; } } C11=SquareMatrixMultiplyRecursive(A11,B11) +SquareMatrixMultiplyRecursive(A12,B21); C12=SquareMatrixMultiplyRecursive(A11,B12) +SquareMatrixMultiplyRecursive(A12,B22); C21=SquareMatrixMultiplyRecursive(A21,B11) +SquareMatrixMultiplyRecursive(A22,B21); C22=SquareMatrixMultiplyRecursive(A21,B12) +SquareMatrixMultiplyRecursive(A22,B22); for (int i=0;i<rows_n;i++) { for(int j=0;j<rows_n;j++) { C.iData[i][j]=C11.iData[i][j]; C.iData[i][j+rows_n]=C12.iData[i][j]; C.iData[i+rows_n][j]=C21.iData[i][j]; C.iData[i+rows_n][j+rows_n]=C22.iData[i][j]; } } } return C; } SquareMatrix Strassen(SquareMatrix A,SquareMatrix B) { SquareMatrix C(A.iRows); int n=A.iRows; if(n==1) { C.iData[0][0]=A.iData[0][0]*B.iData[0][0]; } else { int rows_n=n/2; 
SquareMatrix A11(rows_n),A12(rows_n), A21(rows_n),A22(rows_n), B11(rows_n),B12(rows_n), B21(rows_n),B22(rows_n), C11(rows_n),C12(rows_n), C21(rows_n),C22(rows_n), S1,S2,S3,S4,S5,S6,S7,S8,S9,S10, P1,P2,P3,P4,P5,P6,P7; for (int i=0;i<rows_n;i++) { for(int j=0;j<rows_n;j++) { A11.iData[i][j]=A.iData[i][j]; A12.iData[i][j]=A.iData[i][j+rows_n]; A21.iData[i][j]=A.iData[i+rows_n][j]; A22.iData[i][j]=A.iData[i+rows_n][j+rows_n]; B11.iData[i][j]=B.iData[i][j]; B12.iData[i][j]=B.iData[i][j+rows_n]; B21.iData[i][j]=B.iData[i+rows_n][j]; B22.iData[i][j]=B.iData[i+rows_n][j+rows_n]; } } S1=B12-B22; S2=A11+A12; S3=A21+A22; S4=B21-B11; S5=A11+A22; S6=B11+B22; S7=A12-A22; S8=B21+B22; S9=A11-A21; S10=B11+B12; P1=Strassen(A11,S1); P2=Strassen(S2,B22); P3=Strassen(S3,B11); P4=Strassen(A22,S4); P5=Strassen(S5,S6); P6=Strassen(S7,S8); P7=Strassen(S9,S10); C11=P5+P4-P2+P6; C12=P1+P2; C21=P3+P4; C22=P5+P1-P3-P7; for (int i=0;i<rows_n;i++) { for(int j=0;j<rows_n;j++) { C.iData[i][j]=C11.iData[i][j]; C.iData[i][j+rows_n]=C12.iData[i][j]; C.iData[i+rows_n][j]=C21.iData[i][j]; C.iData[i+rows_n][j+rows_n]=C22.iData[i][j]; } } } return C; }