zoukankan      html  css  js  c++  java
  • 矩阵乘法 之 strassen 算法

    一般情况下矩阵乘法需要三个for循环,时间复杂度为O(n^3),现在我们将矩阵分块
    这里写图片描述

    一般算法需要八次乘法
    r = a * e + b * g ;
    s = a * f + b * h ;
    t = c * e + d * g;
    u = c * f + d * h;

    strassen将其变成7次乘法,因为大家都知道乘法比加减法消耗更多,所有时间复杂更高!
    strassen的处理是:
    令:
    p1 = a * ( f - h )
    p2 = ( a + b ) * h
    p3 = ( c +d ) * e
    p4 = d * ( g - e )
    p5 = ( a + d ) * ( e + h )
    p6 = ( b - d ) * ( g + h )
    p7 = ( a - c ) * ( e + f )

    那么我们可以知道:
    r = p5 + p4 + p6 - p2
    s = p1 + p2
    t = p3 + p4
    u = p5 + p1 - p3 - p7

    我们可以看到上面只有7次乘法和多次加减法,最终达到降低复杂度为O( n^lg7 ) ~= O( n^2.81 );
    代码实现如下:

    // strassen 算法:将矩阵相乘的复杂度降到O(n^lg7) ~= O(n^2.81)  
    // 原理是将8次乘法减少到7次的处理  
    // 现在理论上的最好的算法是O(n^2,367),仅仅是理论上的而已  
    //  
    //  
    // 下面的代码仅仅是简单的实例而已,不必较真哦,呵呵~  
    // 下面的空间可以优化的,此处就不麻烦了~  
    
    #include <stdio.h>  
    
    #define  N  10  
    
    //matrix + matrix  
    void plus( int t[N/2][N/2], int r[N/2][N/2], int s[N/2][N/2] )  
    {  
        int i, j;  
        for( i = 0; i < N / 2; i++ )  
        {  
            for( j = 0; j < N / 2; j++ )  
            {  
                t[i][j] = r[i][j] + s[i][j];  
            }  
        }  
    }  
    
    //matrix - matrix  
    void minus( int t[N/2][N/2], int r[N/2][N/2], int s[N/2][N/2] )  
    {  
        int i, j;  
        for( i = 0; i < N / 2; i++ )  
        {  
            for( j = 0; j < N / 2; j++ )  
            {  
                t[i][j] = r[i][j] - s[i][j];  
            }  
        }  
    }  
    
    //matrix * matrix  
    void mul( int t[N/2][N/2], int r[N/2][N/2], int s[N/2][N/2]  )  
    {  
        int i, j, k;  
        for( i = 0; i < N / 2; i++ )  
        {  
            for( j = 0; j < N / 2; j++ )  
            {  
                t[i][j] = 0;  
                for( k = 0; k < N / 2; k++ )  
                {  
                    t[i][j] += r[i][k] * s[k][j];  
                }  
            }  
        }  
    }  
    
    int main()  
    {  
        int i, j, k;  
        int mat[N][N];  
        int m1[N][N];  
        int m2[N][N];  
        int a[N/2][N/2],b[N/2][N/2],c[N/2][N/2],d[N/2][N/2];  
        int e[N/2][N/2],f[N/2][N/2],g[N/2][N/2],h[N/2][N/2];  
        int p1[N/2][N/2],p2[N/2][N/2],p3[N/2][N/2],p4[N/2][N/2];  
        int p5[N/2][N/2],p6[N/2][N/2],p7[N/2][N/2];  
        int r[N/2][N/2], s[N/2][N/2], t[N/2][N/2], u[N/2][N/2], t1[N/2][N/2], t2[N/2][N/2];  
    
    
        printf("
    Input the first matrix...:
    ");  
        for( i = 0; i < N; i++ )  
        {  
            for( j = 0; j < N; j++ )  
            {  
                scanf("%d", &m1[i][j]);  
            }  
        }  
    
        printf("
    Input the second matrix...:
    ");  
        for( i = 0; i < N; i++ )  
        {  
            for( j = 0; j < N; j++ )  
            {  
                scanf("%d", &m2[i][j]);  
            }  
        }  
    
        // a b c d e f g h  
        for( i = 0; i < N / 2; i++ )  
        {  
            for( j = 0; j < N / 2; j++ )  
            {  
                a[i][j] = m1[i][j];  
                b[i][j] = m1[i][j + N / 2];  
                c[i][j] = m1[i + N / 2][j];  
                d[i][j] = m1[i + N / 2][j + N / 2];  
                e[i][j] = m2[i][j];  
                f[i][j] = m2[i][j + N / 2];  
                g[i][j] = m2[i + N / 2][j];  
                h[i][j] = m2[i + N / 2][j + N / 2];  
            }  
        }  
    
        //p1  
        minus( r, f, h );  
        mul( p1, a, r );   
    
        //p2  
        plus( r, a, b );  
        mul( p2, r, h );  
    
        //p3  
        plus( r, c, d );  
        mul( p3, r, e );  
    
        //p4  
        minus( r, g, e );  
        mul( p4, d, r );  
    
        //p5  
        plus( r, a, d );  
        plus( s, e, f );  
        mul( p5, r, s );  
    
        //p6  
        minus( r, b, d );  
        plus( s, g, h );  
        mul( p6, r, s );  
    
        //p7  
        minus( r, a, c );  
        plus( s, e, f );  
        mul( p7, r, s );  
    
        //r = p5 + p4 - p2 + p6  
        plus( t1, p5, p4 );  
        minus( t2, t1, p2 );  
        plus( r, t2, p6 );  
    
        //s = p1 + p2  
        plus( s, p1, p2 );  
    
        //t = p3 + p4  
        plus( t, p3, p4 );  
    
        //u = p5 + p1 - p3 - p7 = p5 + p1 - ( p3 + p7 )  
        plus( t1, p5, p1 );  
        plus( t2, p3, p7 );  
        minus( u, t1, t2 );  
    
        for( i = 0; i < N / 2; i++ )  
        {  
            for( j = 0; j < N / 2; j++ )  
            {  
                mat[i][j] = r[i][j];  
                mat[i][j + N / 2] = s[i][j];  
                mat[i + N / 2][j] = t[i][j];  
                mat[i + N / 2][j + N / 2] = u[i][j];  
            }  
        }  
    
        printf("
    下面是strassen算法处理结果:
    ");  
        for( i = 0; i < N; i++ )  
        {  
            for( j = 0; j < N; j++ )  
            {  
                printf("%d ", mat[i][j]);  
            }  
            printf("
    ");  
        }  
    
        //下面是朴素算法处理  
        printf("
    下面是朴素算法处理结果:
    ");  
        for( i = 0; i < N; i++ )  
        {  
            for( j = 0; j < N; j++ )  
            {  
                mat[i][j] = 0;  
                for( k = 0; k < N; k++ )  
                {  
                    mat[i][j] += m1[i][j] * m2[i][j];  
                }  
            }  
        }  
    
        for( i = 0; i < N; i++ )  
        {  
            for( j = 0; j < N; j++ )  
            {  
                printf("%d ", mat[i][j]);  
            }  
            printf("
    ");  
        }  
    
        return 0;  
    }  

    现在最好的计算矩阵乘法的复杂度是O( n^2.376 ),不过只是理论上的结果。此处仅仅做参考~

  • 相关阅读:
    Python之数学(math)和随机数(random)
    《图解HTTP》读书笔记
    leetcode1008
    leetcode1007
    leetcode1006
    leetcode1005
    leetcode218
    leetcode212
    leetcode149
    leetcode140
  • 原文地址:https://www.cnblogs.com/yangquanhui/p/4937511.html
Copyright © 2011-2022 走看看