zoukankan      html  css  js  c++  java
  • 第四章 分治策略 4.2 矩阵乘法的Strassen算法

    package chap04_Divide_And_Conquer;
    
    import static org.junit.Assert.*;
    
    import java.util.Arrays;
    
    import org.junit.Test;
    
    /**
     * Matrix multiplication algorithms from CLRS section 4.2: the ordinary
     * triple-loop method and Strassen's divide-and-conquer method, together
     * with the element-wise helpers Strassen needs.
     * 
     * @author xiaojintao
     * 
     */
    public class MatrixOperation {
        /**
         * Ordinary matrix multiplication, c = a * b, where a and b are both n*n
         * square matrices. Runs in Theta(n^3) time.
         * 
         * @param a
         *            left operand, an n*n matrix
         * @param b
         *            right operand, an n*n matrix
         * @return c, a new n*n matrix holding a * b
         */
        static int[][] matrixMultiplicationByCommonMethod(int[][] a, int[][] b) {
            int n = a.length;
            int[][] c = new int[n][n];
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < n; j++) {
                    // Accumulate in a local; the freshly allocated array is
                    // already zero-filled, so no explicit reset is needed.
                    int sum = 0;
                    for (int k = 0; k < n; k++) {
                        sum += a[i][k] * b[k][j];
                    }
                    c[i][j] = sum;
                }
            }
            return c;
        }

        /**
         * Strassen's algorithm for matrix multiplication, c = a * b, where a and
         * b are both n*n square matrices. Each level of recursion performs 7
         * half-size multiplications instead of 8, giving Theta(n^lg7) time.
         * <p>
         * Any n >= 0 is accepted. When n is not a power of two, both operands
         * are zero-padded to the next power of two and the product is cropped
         * back to n*n. Zero padding leaves the top-left n*n block of the product
         * equal to a * b.
         * 
         * @param a
         *            left operand, an n*n matrix
         * @param b
         *            right operand, an n*n matrix
         * @return c, a new n*n matrix holding a * b
         * @throws IllegalArgumentException
         *             if a and b are not both n*n matrices of the same n
         */
        static int[][] matrixMultiplicationByStrassen(int[][] a, int[][] b) {
            int n = a.length;
            requireSquare(a, n, "a");
            requireSquare(b, n, "b");
            if (n == 0) {
                // Without this guard, m = n / 2 == 0 would recurse forever.
                return new int[0][0];
            }
            int size = Integer.highestOneBit(n);
            if (size != n) {
                // n is not a power of two: round up so every split is exact.
                size <<= 1;
                int[][] padded = strassen(resize(a, size), resize(b, size));
                return resize(padded, n);
            }
            return strassen(a, b);
        }

        /**
         * Recursive core of Strassen's algorithm. The side length of a and b
         * must be a power of two; matrixMultiplicationByStrassen guarantees
         * this. Follows CLRS 4.2: ten sums/differences S1..S10, seven recursive
         * products P1..P7, then the four quadrants of the result. The inputs
         * are never modified.
         */
        private static int[][] strassen(int[][] a, int[][] b) {
            int n = a.length;
            if (n == 1) {
                return new int[][] { { a[0][0] * b[0][0] } };
            }
            int m = n / 2;
            int[][] a11 = subMatrix(a, 0, 0, m);
            int[][] a12 = subMatrix(a, 0, m, m);
            int[][] a21 = subMatrix(a, m, 0, m);
            int[][] a22 = subMatrix(a, m, m, m);
            int[][] b11 = subMatrix(b, 0, 0, m);
            int[][] b12 = subMatrix(b, 0, m, m);
            int[][] b21 = subMatrix(b, m, 0, m);
            int[][] b22 = subMatrix(b, m, m, m);

            int[][] s1 = matrixMinus(b12, b22);
            int[][] s2 = matrixAdd(a11, a12);
            int[][] s3 = matrixAdd(a21, a22);
            int[][] s4 = matrixMinus(b21, b11);
            int[][] s5 = matrixAdd(a11, a22);
            int[][] s6 = matrixAdd(b11, b22);
            int[][] s7 = matrixMinus(a12, a22);
            int[][] s8 = matrixAdd(b21, b22);
            int[][] s9 = matrixMinus(a11, a21);
            int[][] s10 = matrixAdd(b11, b12);

            int[][] p1 = strassen(a11, s1);
            int[][] p2 = strassen(s2, b22);
            int[][] p3 = strassen(s3, b11);
            int[][] p4 = strassen(a22, s4);
            int[][] p5 = strassen(s5, s6);
            int[][] p6 = strassen(s7, s8);
            int[][] p7 = strassen(s9, s10);

            // C11 = P5 + P4 - P2 + P6, C12 = P1 + P2,
            // C21 = P3 + P4,          C22 = P5 + P1 - P3 - P7
            int[][] c11 = matrixAdd(matrixMinus(matrixAdd(p5, p4), p2), p6);
            int[][] c12 = matrixAdd(p1, p2);
            int[][] c21 = matrixAdd(p3, p4);
            int[][] c22 = matrixMinus(matrixMinus(matrixAdd(p5, p1), p3), p7);
            return matrixConbine(c11, c12, c21, c22);
        }

        /**
         * Copies the size*size block of src whose top-left corner is at
         * (row, col) into a new matrix.
         */
        private static int[][] subMatrix(int[][] src, int row, int col, int size) {
            int[][] dst = new int[size][size];
            for (int i = 0; i < size; i++) {
                System.arraycopy(src[row + i], col, dst[i], 0, size);
            }
            return dst;
        }

        /**
         * Returns a new size*size matrix. Its top-left corner is copied from the
         * square matrix src and every other cell is zero. Used both to pad up
         * and to crop down.
         */
        private static int[][] resize(int[][] src, int size) {
            int[][] dst = new int[size][size];
            int limit = Math.min(src.length, size);
            for (int i = 0; i < limit; i++) {
                System.arraycopy(src[i], 0, dst[i], 0, limit);
            }
            return dst;
        }

        /**
         * Throws IllegalArgumentException unless matrix has exactly n rows of
         * length n.
         */
        private static void requireSquare(int[][] matrix, int n, String name) {
            if (matrix.length != n) {
                throw new IllegalArgumentException(name + " has " + matrix.length
                        + " rows, expected " + n);
            }
            for (int i = 0; i < n; i++) {
                if (matrix[i].length != n) {
                    throw new IllegalArgumentException(name + "[" + i
                            + "] has length " + matrix[i].length + ", expected "
                            + n);
                }
            }
        }

        /**
         * Matrix addition, c = a + b, for n*n square matrices. Returns a new
         * matrix; the inputs are not modified.
         * 
         * @param a
         *            left operand, an n*n matrix
         * @param b
         *            right operand, an n*n matrix
         * @return c, a new n*n matrix holding a + b
         */
        static int[][] matrixAdd(int[][] a, int[][] b) {
            int n = a.length;
            int[][] c = new int[n][n];
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < n; j++) {
                    c[i][j] = a[i][j] + b[i][j];
                }
            }
            return c;
        }

        /**
         * Matrix subtraction, c = a - b, for n*n square matrices. Returns a new
         * matrix; the inputs are not modified.
         * 
         * @param a
         *            minuend, an n*n matrix
         * @param b
         *            subtrahend, an n*n matrix
         * @return c, a new n*n matrix holding a - b
         */
        static int[][] matrixMinus(int[][] a, int[][] b) {
            int n = a.length;
            int[][] c = new int[n][n];
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < n; j++) {
                    c[i][j] = a[i][j] - b[i][j];
                }
            }
            return c;
        }

        /**
         * Assembles four n*n quadrants into one 2n*2n matrix:
         * 
         * <pre>
         * | t11 t12 |
         * | t21 t22 |
         * </pre>
         * 
         * (The method name is a misspelling of "combine"; it is kept as-is
         * because it is part of the protected API.)
         * 
         * @param t11
         *            top-left quadrant
         * @param t12
         *            top-right quadrant
         * @param t21
         *            bottom-left quadrant
         * @param t22
         *            bottom-right quadrant
         * @return a new 2n*2n matrix
         */
        protected static int[][] matrixConbine(int[][] t11, int[][] t12,
                int[][] t21, int[][] t22) {
            int n = t11.length;
            int[][] c = new int[2 * n][2 * n];
            for (int i = 0; i < n; i++) {
                System.arraycopy(t11[i], 0, c[i], 0, n);
                System.arraycopy(t12[i], 0, c[i], n, n);
                System.arraycopy(t21[i], 0, c[i + n], 0, n);
                System.arraycopy(t22[i], 0, c[i + n], n, n);
            }
            return c;
        }

        @Test
        public void testName() throws Exception {
            // 4*4: a power of two, handled by Strassen without padding.
            int[][] m = { { 1, 2, 3, 4 }, { 5, 6, 7, 8 }, { 9, 10, 11, 12 },
                    { 13, 14, 15, 16 } };
            int[][] n = { { 1, 3, 5, 7 }, { 2, 4, 6, 8 }, { 4, 3, 2, 1 },
                    { 9, 8, 7, 6 } };
            int[][] mn = { { 53, 52, 51, 50 }, { 117, 124, 131, 138 },
                    { 181, 196, 211, 226 }, { 245, 268, 291, 314 } };
            assertMatrixEquals(mn, matrixMultiplicationByCommonMethod(m, n));
            assertMatrixEquals(mn, matrixMultiplicationByStrassen(m, n));

            // 3*3: not a power of two, exercises the zero-padding path.
            int[][] a = { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } };
            int[][] b = { { 1, 3, 5 }, { 2, 4, 6 }, { 9, 8, 7 } };
            int[][] ab = { { 32, 35, 38 }, { 68, 80, 92 }, { 104, 125, 146 } };
            assertMatrixEquals(ab, matrixMultiplicationByCommonMethod(a, b));
            assertMatrixEquals(ab, matrixMultiplicationByStrassen(a, b));

            // Degenerate sizes.
            assertMatrixEquals(new int[][] { { -12 } },
                    matrixMultiplicationByStrassen(new int[][] { { 3 } },
                            new int[][] { { -4 } }));
            assertMatrixEquals(new int[0][0],
                    matrixMultiplicationByStrassen(new int[0][0], new int[0][0]));
        }

        /** Fails the enclosing test unless the two matrices are equal. */
        private static void assertMatrixEquals(int[][] expected, int[][] actual) {
            if (!Arrays.deepEquals(expected, actual)) {
                throw new AssertionError("expected "
                        + Arrays.deepToString(expected) + " but was "
                        + Arrays.deepToString(actual));
            }
        }
    }

    暴力求解的时间复杂度为 $O(n^3)$,Strassen 算法为 $O(n^{\log_2 7}) \approx O(n^{2.81})$。

  • 相关阅读:
    豆瓣书籍数据采集
    动画精灵与碰撞检测
    图形
    模块
    对象
    函数
    列表与字典
    python 感悟
    SqlServer自动备份数据库(没有sql代理服务的情况下)
    关于AD获取成员隶属于哪些组InvokeGet("memberOf")的问题
  • 原文地址:https://www.cnblogs.com/xiaojintao/p/3776407.html
Copyright © 2011-2022 走看看