zoukankan      html  css  js  c++  java
  • n阶方阵乘法straseen

    原理:分块矩阵乘法,进行8次矩阵乘法,时间复杂度为 $ heta(n^3) = heta(n^{lg{8}}) $ , 改进后仅需要7次乘法, 时间复杂度为 $ heta(n^{lg{7}})$
    具体推到见算法导论中利用主定理推导时间复杂度

    def matrix_divide(A):
        rows = len(A)
        mid = rows // 2
        A11 = [[0]*mid for _ in range(mid)]
        A12 = [[0]*mid for _ in range(mid)]
        A21 = [[0]*mid for _ in range(mid)]
        A22 = [[0]*mid for _ in range(mid)]
    
        for i in range(mid):
            for j in range(mid):
                A11[i][j] = A[i][j]
                A12[i][j] = A[i][mid+j]
                A21[i][j] = A[mid+i][j]
                A22[i][j] = A[mid+i][mid+j]
        return A11, A12, A21, A22
    
    def matrix_add(A, B):
        rows = len(A)
        C = [[0]*rows for _ in range(rows)]
        for i in range(rows):
            for j in range(rows):
                C[i][j] = A[i][j] + B[i][j]
        return C
    
    def matrix_sub(A, B):
        rows = len(A)
        C = [[0]*rows for _ in range(rows)]
        for i in range(rows):
            for j in range(rows):
                C[i][j] = A[i][j] - B[i][j]
        return C
    
    
    def matrix_merge(C11, C12, C21, C22):
        rows = len(C11)
        n = rows * 2
        C = [[0]*n for _ in range(n)]
        for i in range(rows):
            for j in range(rows):
                C[i][j] = C11[i][j]
                C[i][rows+j] = C12[i][j]
                C[rows+i][j] = C21[i][j]
                C[rows+i][rows+j] = C22[i][j]
        return C
    
    
    def strassen(A, B):
        n = len(A)
        C = [[0] for _ in range(n)]
        if n == 1:
            C[0][0] = A[0][0]*B[0][0]
            return C
        A11, A12, A21, A22 = matrix_divide(A)
        B11, B12, B21, B22 = matrix_divide(B)
    
        S1 = matrix_sub(B12, B22)
        S2 = matrix_add(A11, A12)
        S3 = matrix_add(A21, A22)
        S4 = matrix_sub(B21, B11)
        S5 = matrix_add(A11, A22)
        S6 = matrix_add(B11, B22)
        S7 = matrix_sub(A12, A22)
        S8 = matrix_add(B21, B22)
        S9 = matrix_sub(A11, A21)
        S10 = matrix_add(B11, B12)
        
        P1 = strassen(A11, S1)
        P2 = strassen(S2, B22)
        P3 = strassen(S3, B11)
        P4 = strassen(A22, S4)
        P5 = strassen(S5, S6)
        P6 = strassen(S7, S8)
        P7 = strassen(S9, S10)
    
        C11 = matrix_add(P5, matrix_sub(P4, matrix_sub(P2, P6)))
        C12 = matrix_add(P1, P2)
        C21 = matrix_add(P3, P4)
        C22 = matrix_add(P5, matrix_sub(P1, matrix_add(P3, P7)))
        
        return matrix_merge(C11, C12, C21, C22)
    def main():
        A = [[1,1,1,1],[2,2,2,2],[3,3,3,3],[4,4,4,4]]
        B = [[5,5,5,5],[6,6,6,6],[7,7,7,7],[8,8,8,8]]
        C = strassen(A, B)
        print(C)
    if __name__ == '__main__':
        main()
    
  • 相关阅读:
    c++ stl string char* 向 string 转换的问题
    不要在疲惫中工作
    今天
    悠然自得
    忙与闲
    <转>LuaTinker的bug和缺陷
    匿名管道
    SetWindowHookEx 做消息响应
    最近工作
    实现网页页面跳转的几种方法(meta标签、js实现、php实现)
  • 原文地址:https://www.cnblogs.com/vito_wang/p/10806816.html
Copyright © 2011-2022 走看看