zoukankan      html  css  js  c++  java
  • 【算法导论C++代码】Strassen算法

    简单方阵矩乘法

    SQUARE-MATRIX-MULTIPLY(A,B)
    
    1  n = A.rows
    
    2  let C be a new n*n matrix
    
    3  for i = 1 to n
    
    4      for j = 1 to n
    
    5          cij = 0
    
    6          for k = 1 to n
    
    7              cij = cij + aik·bkj
    
    8  return C
    
    一个简单的分治算法
    
    SQUARE-MATRIX-MULTIPLY-RECURSIVE(A,B)
    
    1 n = A.rows
    
    2 let C be a new n*n matrix
    
    3 if n==1
    
    4 c11=a11·b11
    
    5 else partition A,B,and C as in equations (4.9)
    
    6  C11=SQUARE-MATRIX-MULTIPLY-RECURSIVE(A11,B11)
    
          +SQUARE-MATRIX-MULTIPLY-RECURSIVE(A12,B21)
    
    7   C12=SQUARE-MATRIX-MULTIPLY-RECURSIVE(A11,B12)
    
          +SQUARE-MATRIX-MULTIPLY-RECURSIVE(A12,B22)
    
     
    
    8 C21=SQUARE-MATRIX-MULTIPLY-RECURSIVE(A21,B11)
    
          +SQUARE-MATRIX-MULTIPLY-RECURSIVE(A22,B21)
    
    9 C22=SQUARE-MATRIX-MULTIPLY-RECURSIVE(A21,B12)
    
          +SQUARE-MATRIX-MULTIPLY-RECURSIVE(A22,B22)
    
    10 return C
    
     
    
    矩阵乘法的Strassen算法
    
    SQUARE-MATRIX-STRASSEN-RECURSIVE(A,B)
    
    1 n=A.rows
    
    2 let C be a new n*n matrix
    
    3 if n==1
    
    4    c11=a11·b11
    
    5 else partition A,B,and C as in equations (4.9)
    
    6       S1=B12-B22;
    
    7   S2=A11+A12;
    
    8   S3=A21+A22;
    
    9   S4=B21-B11;
    
    10   S5=A11+A22;
    
    11   S6=B11+B22;
    
    12   S7=A12-A22;
    
    13   S8=B21+B22;
    
    14   S9=A11-A21;
    
    15   S10=B11+B12;
    
    16   P1=SQUARE-MATRIX-STRASSEN-RECURSIVE(A11,S1);
    
    17   P2=SQUARE-MATRIX-STRASSEN-RECURSIVE(S2,B22);
    
    18   P3=SQUARE-MATRIX-STRASSEN-RECURSIVE(S3,B11);
    
    19   P4=SQUARE-MATRIX-STRASSEN-RECURSIVE(A22,S4);
    
    20   P5=SQUARE-MATRIX-STRASSEN-RECURSIVE(S5,S6);
    
    21   P6=SQUARE-MATRIX-STRASSEN-RECURSIVE(S7,S8);
    
    22   P7=SQUARE-MATRIX-STRASSEN-RECURSIVE(S9,S10);
    
    23   C11=P5+P4-P2+P6;
    
    24   C12=P1+P2;
    
    25   C21=P3+P4;
    
    26   C22=P5+P1-P3-P7;
    
    27 return C;

    /*C++代码。书上给的分解矩阵的做法是用角标计算而不用建立新的对象,不过我并没有想到可以不用新建对象而进行递归的办法,所以这里还是和书上有些不一样的。另外因为演示,所以新建了个类,不过这个类并不稳定,仅作测试时了解功能就好*/
    Matrix.h
    #ifndef MATRIX_H_INCLUDED
    #define MATRIX_H_INCLUDED

    // Square matrix of ints, stored as a table of row pointers so that an
    // element is addressed as iData[row][col].
    //
    // The class declares no copy constructor or copy assignment operator, so
    // copying an object (including passing it by value) copies the iData
    // pointer only: the copy and the original refer to the same storage.
    //
    // The include guard around this declaration lets the header be included
    // more than once in a translation unit without a redefinition error.
    class SquareMatrix
    {
    public:
        // Creates a matrix object without allocating any storage.
        SquareMatrix();
        // Wraps an existing row-pointer table; the table is not copied.
        SquareMatrix(int **data,int rows);
        // Allocates a rows x rows matrix (delegates to CreateSqMa).
        SquareMatrix(int rows);
        ~SquareMatrix();
        // Allocates rows x rows storage and fills it with zeros.
        int CreateSqMa(int rows);
        // Fills the matrix from a flat array holding rows*rows values.
        int SetData(int rows,int *data);
        // Row-pointer table: iData[r] points at row r.
        int    **iData;
        // Number of rows (and columns) of the matrix.
        int iRows;
        // Element-wise sum and difference.
        friend    SquareMatrix operator+(SquareMatrix A,SquareMatrix B);
        friend    SquareMatrix operator-(SquareMatrix A,SquareMatrix B);
        // Writes the matrix to std::cout, one row per line.
        int SprintSqMa();
    };

    #endif // MATRIX_H_INCLUDED
    Matrix.cpp
    #include <iostream>
    #include "Matrix.h"
    // Creates an empty matrix: no storage attached, dimension 0.
    //
    // Both members are set explicitly. Leaving them unset would give a
    // default-constructed object an indeterminate iRows and iData, so any
    // member that loops up to iRows or dereferences iData before CreateSqMa
    // has been called would have undefined behaviour.
    SquareMatrix::SquareMatrix()
        : iData(0), iRows(0)
    {
    }
    // Builds a rows x rows matrix; allocation and zero-filling are delegated
    // to CreateSqMa.
    SquareMatrix::SquareMatrix(int rows)
    {
        CreateSqMa(rows);
    }
    // Wraps an existing row-pointer table. Only the pointer is stored; the
    // elements themselves are not copied.
    SquareMatrix::SquareMatrix(int **data,int rows)
        : iData(data), iRows(rows)
    {
    }
    
    // Destructor. The body is empty, so the storage reached through iData is
    // not released here.
    // NOTE(review): adding delete[] here on its own would not be safe. The
    // class has no user-defined copy constructor or copy assignment, and the
    // operators and multiply functions in this file take and return
    // SquareMatrix by value, so several objects can hold the same iData
    // pointer and each would try to free it.
    SquareMatrix::~SquareMatrix()
    {
    
    }
    // Allocates storage for a rows x rows matrix and sets every element to 0.
    //
    // rows: dimension of the matrix to create.
    // Returns 0.
    int SquareMatrix::CreateSqMa(int rows)
    {
        iRows = rows;
        iData = new int *[rows];
        for (int r = 0; r < iRows; ++r)
        {
            // The trailing () value-initialises the row, i.e. all entries 0.
            iData[r] = new int[rows]();
        }
        return 0;
    }
    // Copies rows*rows values from a flat, row-major array into the matrix.
    //
    // rows: dimension of the data; must not exceed the allocated dimension.
    // data: array of at least rows*rows ints; data[r*rows + c] becomes the
    //       element at row r, column c.
    // Returns 0 on success, -1 if the arguments are rejected (null data,
    // negative rows, or rows larger than the allocated matrix); in that case
    // the matrix is left untouched.
    int SquareMatrix::SetData(int rows,int *data)
    {
        // Writing more rows/columns than were allocated would run past the
        // end of iData, so refuse instead of corrupting memory.
        if (data == 0 || rows < 0 || rows > iRows)
        {
            return -1;
        }
        for (int r = 0; r < rows; r++)
        {
            for (int c = 0; c < rows; c++)
            {
                // Row index first: indexing as [c][r] here would store the
                // transpose of the input array.
                iData[r][c]=data[r*rows+c];
            }
        }
        iRows=rows;
        return 0;
    }
    
    // Element-wise sum of two square matrices.
    //
    // The result has A's dimension. The operands are expected to have equal
    // dimensions; if they differ, only the top-left block common to both is
    // summed and the remaining elements of the result stay 0. Bounding the
    // loops by the smaller dimension keeps every access inside both A and B.
    SquareMatrix operator+(SquareMatrix A,SquareMatrix B)
    {
        SquareMatrix C(A.iRows);
        const int n=(B.iRows<A.iRows)?B.iRows:A.iRows;

        for(int i=0;i<n;i++)
        {
            for(int j=0;j<n;j++)
            {
                C.iData[i][j]=A.iData[i][j]+B.iData[i][j];
            }
        }
        return C;
    }
    // Element-wise difference A - B of two square matrices.
    //
    // The result has A's dimension. The operands are expected to have equal
    // dimensions; if they differ, only the top-left block common to both is
    // subtracted and the remaining elements of the result stay 0. Bounding
    // the loops by the smaller dimension keeps every access inside both
    // A and B.
    SquareMatrix operator-(SquareMatrix A,SquareMatrix B)
    {
        SquareMatrix C(A.iRows);
        const int n=(B.iRows<A.iRows)?B.iRows:A.iRows;

        for(int i=0;i<n;i++)
        {
            for(int j=0;j<n;j++)
            {
                C.iData[i][j]=A.iData[i][j]-B.iData[i][j];
            }
        }
        return C;
    }
    // Prints the matrix to std::cout: each element is followed by a space and
    // each row is terminated with std::endl.
    // Returns 0.
    int SquareMatrix::SprintSqMa()
    {
        for (int row = 0; row < iRows; ++row)
        {
            for (int col = 0; col < iRows; ++col)
            {
                std::cout << iData[row][col] << ' ';
            }
            // End of the row: newline plus flush.
            std::cout << std::endl;
        }
        return 0;
    }
    
    MAIN.cpp
    #include <iostream>
    #include "Matrix.h"
    using namespace std;
    // Forward declarations of the three multiplication routines defined
    // after main().
    // Classical triple-loop product.
    SquareMatrix SquareMatrixMultiply(SquareMatrix A,SquareMatrix B);
    // Divide-and-conquer product using eight recursive half-size products.
    SquareMatrix SquareMatrixMultiplyRecursive(SquareMatrix A,SquareMatrix B);
    // Strassen's product using seven recursive half-size products.
    SquareMatrix Strassen(SquareMatrix A,SquareMatrix B);
    int main()
    {
        SquareMatrix A,B,C;
    
        B.CreateSqMa(4);
        C.CreateSqMa(4);
        A.CreateSqMa(4);
        int arr[16]={1,2,1,1,1,1,1,1,1,1,1,1,1,1,1,1};
        B.SetData(4,arr);
        A.SetData(4,arr);
    
        //C=A+B;
        //C=SquareMatrixMultiply(A,B);
        C=SquareMatrixMultiplyRecursive(A,B);
        //C=Strassen(A,B);
        cout<<"算法导论4.2矩阵乘法Strassen算法"<<endl;
    
        A.SprintSqMa();
        cout<<endl;
    
        B.SprintSqMa();
        cout<<endl;
    
        C.SprintSqMa();
        cout<<endl;
    
        system("pause");
        return 0;
    }
    
    // Classical matrix product: element (row, col) of the result is the dot
    // product of row `row` of A with column `col` of B. The result has A's
    // dimension.
    SquareMatrix SquareMatrixMultiply(SquareMatrix A,SquareMatrix B)
    {
        const int size = A.iRows;
        SquareMatrix product(size);
        for (int row = 0; row < size; ++row)
        {
            int *outRow = product.iData[row];
            int *lhsRow = A.iData[row];
            for (int col = 0; col < size; ++col)
            {
                // Starts from the stored value, which the constructor set to 0.
                int acc = outRow[col];
                for (int k = 0; k < size; ++k)
                {
                    acc = acc + lhsRow[k] * B.iData[k][col];
                }
                outRow[col] = acc;
            }
        }
        return product;
    }
    
    // Divide-and-conquer matrix product.
    //
    // A and B are split into four half-size quadrants each; every quadrant of
    // the result is the sum of two recursive half-size products (eight
    // recursive calls in total). The result has A's dimension.
    //
    // Special cases:
    //  - dimension 0: an empty matrix is returned;
    //  - dimension 1: the single elements are multiplied directly;
    //  - odd dimension > 1: the matrix cannot be split into equal quadrants,
    //    so the product is computed with SquareMatrixMultiply instead.
    SquareMatrix SquareMatrixMultiplyRecursive(SquareMatrix A,SquareMatrix B)
    {
        SquareMatrix C(A.iRows);
        const int n=A.iRows;
        if(n==0)
        {
            // Nothing to multiply. Without this check the halving below would
            // recurse forever, since 0/2 is still 0.
            return C;
        }
        if(n==1)
        {
            C.iData[0][0]=A.iData[0][0]*B.iData[0][0];
            return C;
        }
        if(n%2!=0)
        {
            // n/2 would truncate and silently drop the last row and column.
            return SquareMatrixMultiply(A,B);
        }

        const int rows_n=n/2;
        SquareMatrix A11(rows_n),A12(rows_n),
            A21(rows_n),A22(rows_n),
            B11(rows_n),B12(rows_n),
            B21(rows_n),B22(rows_n);
        // Copy the four quadrants of A and B into their own matrices.
        for (int i=0;i<rows_n;i++)
        {
            for(int j=0;j<rows_n;j++)
            {
                A11.iData[i][j]=A.iData[i][j];
                A12.iData[i][j]=A.iData[i][j+rows_n];
                A21.iData[i][j]=A.iData[i+rows_n][j];
                A22.iData[i][j]=A.iData[i+rows_n][j+rows_n];
                B11.iData[i][j]=B.iData[i][j];
                B12.iData[i][j]=B.iData[i][j+rows_n];
                B21.iData[i][j]=B.iData[i+rows_n][j];
                B22.iData[i][j]=B.iData[i+rows_n][j+rows_n];
            }
        }

        // Each result quadrant is initialised straight from its expression,
        // so no placeholder matrix is allocated and then overwritten.
        SquareMatrix C11=SquareMatrixMultiplyRecursive(A11,B11)
            +SquareMatrixMultiplyRecursive(A12,B21);

        SquareMatrix C12=SquareMatrixMultiplyRecursive(A11,B12)
            +SquareMatrixMultiplyRecursive(A12,B22);

        SquareMatrix C21=SquareMatrixMultiplyRecursive(A21,B11)
            +SquareMatrixMultiplyRecursive(A22,B21);

        SquareMatrix C22=SquareMatrixMultiplyRecursive(A21,B12)
            +SquareMatrixMultiplyRecursive(A22,B22);

        // Assemble the four quadrants into the full result.
        for (int i=0;i<rows_n;i++)
        {
            for(int j=0;j<rows_n;j++)
            {
                C.iData[i][j]=C11.iData[i][j];
                C.iData[i][j+rows_n]=C12.iData[i][j];
                C.iData[i+rows_n][j]=C21.iData[i][j];
                C.iData[i+rows_n][j+rows_n]=C22.iData[i][j];
            }
        }
        return C;
    }
    
    // Strassen's matrix product.
    //
    // A and B are split into four half-size quadrants each. Ten sums and
    // differences of quadrants (S1..S10) feed seven recursive half-size
    // products (P1..P7), which are then combined into the four quadrants of
    // the result. The result has A's dimension.
    //
    // Special cases:
    //  - dimension 0: an empty matrix is returned;
    //  - dimension 1: the single elements are multiplied directly;
    //  - odd dimension > 1: the matrix cannot be split into equal quadrants,
    //    so the product is computed with SquareMatrixMultiply instead.
    SquareMatrix Strassen(SquareMatrix A,SquareMatrix B)
    {
        SquareMatrix C(A.iRows);
        const int n=A.iRows;
        if(n==0)
        {
            // Nothing to multiply. Without this check the halving below would
            // recurse forever, since 0/2 is still 0.
            return C;
        }
        if(n==1)
        {
            C.iData[0][0]=A.iData[0][0]*B.iData[0][0];
            return C;
        }
        if(n%2!=0)
        {
            // n/2 would truncate and silently drop the last row and column.
            return SquareMatrixMultiply(A,B);
        }

        const int rows_n=n/2;
        SquareMatrix A11(rows_n),A12(rows_n),
            A21(rows_n),A22(rows_n),
            B11(rows_n),B12(rows_n),
            B21(rows_n),B22(rows_n);
        // Copy the four quadrants of A and B into their own matrices.
        for (int i=0;i<rows_n;i++)
        {
            for(int j=0;j<rows_n;j++)
            {
                A11.iData[i][j]=A.iData[i][j];
                A12.iData[i][j]=A.iData[i][j+rows_n];
                A21.iData[i][j]=A.iData[i+rows_n][j];
                A22.iData[i][j]=A.iData[i+rows_n][j+rows_n];
                B11.iData[i][j]=B.iData[i][j];
                B12.iData[i][j]=B.iData[i][j+rows_n];
                B21.iData[i][j]=B.iData[i+rows_n][j];
                B22.iData[i][j]=B.iData[i+rows_n][j+rows_n];
            }
        }

        // The ten quadrant sums and differences.
        SquareMatrix S1=B12-B22;
        SquareMatrix S2=A11+A12;
        SquareMatrix S3=A21+A22;
        SquareMatrix S4=B21-B11;
        SquareMatrix S5=A11+A22;
        SquareMatrix S6=B11+B22;
        SquareMatrix S7=A12-A22;
        SquareMatrix S8=B21+B22;
        SquareMatrix S9=A11-A21;
        SquareMatrix S10=B11+B12;

        // The seven recursive products.
        SquareMatrix P1=Strassen(A11,S1);
        SquareMatrix P2=Strassen(S2,B22);
        SquareMatrix P3=Strassen(S3,B11);
        SquareMatrix P4=Strassen(A22,S4);
        SquareMatrix P5=Strassen(S5,S6);
        SquareMatrix P6=Strassen(S7,S8);
        SquareMatrix P7=Strassen(S9,S10);

        // Each result quadrant is initialised straight from its expression,
        // so no placeholder matrix is allocated and then overwritten.
        SquareMatrix C11=P5+P4-P2+P6;
        SquareMatrix C12=P1+P2;
        SquareMatrix C21=P3+P4;
        SquareMatrix C22=P5+P1-P3-P7;

        // Assemble the four quadrants into the full result.
        for (int i=0;i<rows_n;i++)
        {
            for(int j=0;j<rows_n;j++)
            {
                C.iData[i][j]=C11.iData[i][j];
                C.iData[i][j+rows_n]=C12.iData[i][j];
                C.iData[i+rows_n][j]=C21.iData[i][j];
                C.iData[i+rows_n][j+rows_n]=C22.iData[i][j];
            }
        }
        return C;
    }
  • 相关阅读:
    mktemp -t -d用法
    使用getopts处理输入参数
    linux中$1的意思
    linux中的set -e 与set -o pipefail
    在windows 7 和linux上安装xlwt和xlrd
    nginx map使用方法
    Linux crontab下关于使用date命令和sudo命令的坑
    东哥讲义
    ldapsearch使用
    date 命令之日期和秒数转换
  • 原文地址:https://www.cnblogs.com/fastcam/p/4778045.html
Copyright © 2011-2022 走看看