zoukankan      html  css  js  c++  java
  • 小矩阵相乘效率对比:lapack, cblas, 手写函数

    我们需要做很多很多小矩阵相乘(维数只有几十),但是次数很多,所以用哪个矩阵库的函数对我们很重要。这里写一个很小的测试代码,测试lapack(包含着朴素的blas),cblas,还有手写函数,对比它们做小矩阵相乘的效率。
    对于给定的维数,这三种办法,每种都做1000次方阵相乘\(AB = C\),每次相乘用的矩阵 \(A,B\) 都是随机的。计时用的是 clock(),取的是 cpu 时间。

    #include<iostream>
    using namespace std;
    #include<fstream>
    #include<cmath>
    #include<vector>
    #include<complex>
    
    #include "mkl.h"
    
    extern "C" void dgemm_(char *TRANSA, char *TRANSB, int *M, int *N, int *K, double* ALPHA, double *A, int* LDA, double *B, int* LDB, double* BETA, double *C, int* LDC);
    
    /*
     * wraps dgemm_() in lapack, uses one of the optional modes of it to do C := A B
     * int n                dimension
     * double * A           A[ n*n ]
     * double * B           B[ n*n ]
     * double * C           C[ n*n ]
     */
    void lapack_dgemm( int n, double * A, double * B, double * C ){
    
            // dgemm (... ) : C = alpha * op( A ) * op( B ) + beta * C
            char TRANSA='N'; // op( A ) = A
            char TRANSB='N'; // op( B ) = B
            int M=n; // number of rows in A
            int N=n; // number of columns in B
            int K=n; // number of columns in A, also equals number of rows in B
            double ALPHA=1.0; // alpha
            double BETA=0.0; // beta
            int LDA=n; // leading dimension of A
            int LDB=n; // leading dimension of B
            int LDC=n; // leading dimension of C
    
            dgemm_(&TRANSA, &TRANSB, &M, &N, &K, &ALPHA, B, &LDA, A, &LDB, &BETA, C, &LDC);
            // because dgemm is written in fortran, it actually gets B^\top A^\top = ( AB )^\top, an (AB)^\top will actually be stored in fortran manner, that is AB in C++
    }
    
    void mtx_multiply( int n, double * A, double * B, double * C ){
    	double y;
    	for(int i=0;i<n;i++){
    		for(int j=0;j<n;j++){
    			y = 0;
    			for(int k=0;k<n;k++) y += A[i*n+k] * B[k*n+j];
    			C[i*n+j] = y;
    		}
    	}
    };
    
    void cmtx_multiply( int n, complex<double> * cA, complex<double> * cB, complex<double> * cC ){
    	
    	//#pragma omp parallel for
    	for(int i=0;i<n;i++){
    		complex<double> y;
    		for(int j=0;j<n;j++){
    			y = 0;
    			for(int k=0;k<n;k++) y += cA[i*n+k] * cB[k*n+j];
    			cC[i*n+j] = y;
    		}
    	}
    };
    
    void cblaszgemm3m( int n, complex<double> * A, complex<double> * B, complex<double> * C ){
    	complex<double> alpha = {1,0}, beta = {0,0};
    	cblas_zgemm3m( CblasRowMajor, CblasNoTrans, CblasNoTrans, n, n, n, &alpha, A, n, B, n, &beta, C, n );
    }
    
    void cblaszgemm( int n, complex<double> * A, complex<double> * B, complex<double> * C ){
    	complex<double> alpha = {1,0}, beta = {0,0};
    	cblas_zgemm( CblasRowMajor, CblasNoTrans, CblasNoTrans, n, n, n, &alpha, A, n, B, n, &beta, C, n );
    }
    
    void printmtx( int n, complex<double> * A ){
    	for(int i=0;i<n;i++){
    		for(int j=0;j<n;j++)cout<< A[i*n+j]<<", ";
    		cout<<endl;
    	}
    }
    
    void randcmtx( int n, complex<double> * A ){
    	for(int i=0;i<n*n;i++) A[i] = ((double)rand())/RAND_MAX;
    }
    void randmtx( int n, double * A ){
    	for(int i=0;i<n*n;i++) A[i] = ((double)rand())/RAND_MAX;
    }
    
    int main(){
    
    	/*
    	// test: A = [ 0, 1, 0, 0 ], B = [ 0, 1, -1, 0 ]
    	// AB = [ -1, 0, 0, 0 ], A^T B^T = [ 0, 0, 0, -1 ]
    	int n = 2;
    	double A[4] = { 0, 1, 0, 0 };
    	double B[4] = { 0, 1, -1, 0 };
    	double C[4];
    
    	lapack_dgemm( n, A, B, C );
    
    	cout<<"C: "; for(int i=0;i<4;i++) cout<<C[i]<<","; cout<<endl;
    	*/
    
    	vector<int> dim = {10, 20, 30, 40  };
    	vector<double> ave_t_lapack_dgemm;
    	vector<double> ave_t_cblas_dgemm;
    	vector<double> ave_t_hand_dgemm;
    	vector<double> ave_t_cblas_zgemm3m;
    	vector<double> ave_t_cblas_zgemm;
    	vector<double> ave_t_hand_zgemm;
    	vector<double> ratio;
    
    	int nrepeat = 1000; double x;
    
    	for(auto n : dim ){
    		cout<<" n = "<<n<<endl;
    		double * A = new double [ n*n ];
    		for(int i=0;i<n*n;i++) A[i] = ((double)rand())/RAND_MAX;
    		double * B = new double [ n*n ];
    		for(int i=0;i<n*n;i++) B[i] = ((double)rand())/RAND_MAX;
    		double * C = new double [ n*n ];
    
    		clock_t t1, t2, t3, t4;
    
    		double alpha = 1, beta = 0;
    		double t_lapack_dgemm = 0, t_cblas_dgemm = 0, t_hand_dgemm = 0;
    		for(int i=0;i<nrepeat;i++){
    			randmtx( n, A ); randmtx( n, B );
    			t1 = clock(); lapack_dgemm( n, A, B, C ); t2 = clock(); t_lapack_dgemm += (t2-t1);
    			t1 = clock(); 
    			cblas_dgemm( CblasRowMajor, CblasNoTrans, CblasNoTrans, n, n, n, alpha, A, n, B, n, beta, C, n );
    			t2 = clock(); t_cblas_dgemm += (t2-t1);
    			t1 = clock(); mtx_multiply( n, A, B, C ); t2 = clock(); t_hand_dgemm += (t2-t1);
    		}
    
    		x = t_lapack_dgemm/CLOCKS_PER_SEC/nrepeat;	
    		cout<<" lapack dgemm:  " << x <<" s."<<endl;
    		ave_t_lapack_dgemm.push_back( x );
    
    		x = t_cblas_dgemm/CLOCKS_PER_SEC/nrepeat;
    		cout<<" cblas dgemm: "<< x <<" s."<<endl;
    		ave_t_cblas_dgemm.push_back( x );
    
    		x = t_hand_dgemm/CLOCKS_PER_SEC/nrepeat;
    		cout<<" hand written gemm: " << x << " s."<<endl;
    		ave_t_hand_dgemm.push_back( x );
    
    		complex<double> * cA = new complex<double> [ n*n ];
    		complex<double> * cB = new complex<double> [ n*n ];
    		complex<double> * cC = new complex<double> [ n*n ];
    
    		double t_cblas_zgemm3m = 0, t_cblas_zgemm = 0, t_hand_zgemm = 0;
    		for(int i=0;i<nrepeat;i++){
    			randcmtx(n, cA); randcmtx(n, cB);
    			t1 = clock(); cblaszgemm3m( n, cA, cB, cC ); t2 = clock(); t_cblas_zgemm3m += (t2-t1);
    			t1 = clock(); cblaszgemm( n, cA, cB, cC ); t2 = clock(); t_cblas_zgemm += (t2-t1);
    			t1 = clock(); cmtx_multiply( n, cA, cB, cC ); t2 = clock(); t_hand_zgemm += (t2-t1);
    		}
    		x = t_cblas_zgemm3m /CLOCKS_PER_SEC/nrepeat;
    		cout<<" cblas zgemm3m: " << x <<" s."<<endl;
    		ave_t_cblas_zgemm3m.push_back( x );
    
    		x = t_cblas_zgemm /CLOCKS_PER_SEC/nrepeat;
    		cout<<" cblas zgemm: " << x <<" s."<<endl;
    		ave_t_cblas_zgemm.push_back( x );
    
    		x = t_hand_zgemm / CLOCKS_PER_SEC/nrepeat;
    		cout<<" hand zgemm: " << x <<" s."<<endl;
    		ave_t_hand_zgemm.push_back( x );
    
    		delete [] A; delete [] B; delete [] C;
    		delete [] cA; delete [] cB; delete [] cC;
    	}
    
    	cout<<" ave_t_lapack_dgemm = [ "; for(auto t : ave_t_lapack_dgemm) cout<<t<<", "; cout<<"]\n";
    	cout<<" ave_t_cblas_dgemm = [ "; for(auto t : ave_t_cblas_dgemm) cout<<t<<", "; cout<<"]\n";
    	cout<<" ave_t_hand_dgemm = [ "; for(auto t : ave_t_hand_dgemm) cout<<t<<", "; cout<<"]\n";
    	cout<<" ave_t_cblas_zgemm3m = [ "; for(auto t : ave_t_cblas_zgemm3m) cout<<t<<", "; cout<<"]\n";
    	cout<<" ave_t_cblas_zgemm = [ "; for(auto t : ave_t_cblas_zgemm) cout<<t<<", "; cout<<"]\n";
    	cout<<" ave_t_hand_zgemm = [ "; for(auto t : ave_t_hand_zgemm) cout<<t<<", "; cout<<"]\n";
    
    	return 0;
    }
    

    编译:

    icc gemm.cpp -qmkl -lblas -lgsl -O3
    

    运行:

    ./a.out
    

    做出来的结果如下
    image

    结论是 cblas 比 朴素的 blas或者手写函数都要强(of course)。
    但实践中有几个点我要记一下:

    • 编译时如果不开 -O3,cblas 很慢,在 n=10, 20 时不如手写
    • 如果不是在代码中运行 1000 次取平均,只跑一次进行比较的话,cblas 在 n=10,20,30 也不如手写函数。这个我不是完全理解,但考虑到实践中是密集的矩阵运算,所以运行1000次取平均似乎更接近实际场景。在实践中用 cblas 确实也比手写更快,PVPC Si28 中用 zgemm3m 比用手写函数要耗时少25%,用手写函数要 16s, 用 zgemm3m 要 12s。所以暂时不纠结这个问题了,先用着 cblas。
    • 图中zgemm 似乎比 zgemm3m 还快一点,实践中也得到了印证,在 PVPC Si28 中,用 zgemm 只要 10s。
  • 相关阅读:
    操作系统复习
    Google hack语法
    c++的set重载运算符
    华为笔试题
    Flume+Kafka整合
    kafka相关知识点总结
    kafka中生产者和消费者API
    Kafka集群环境搭建
    Storm消息容错机制(ack-fail机制)
    Storm通信机制(了解)
  • 原文地址:https://www.cnblogs.com/luyi07/p/15562722.html
Copyright © 2011-2022 走看看