zoukankan      html  css  js  c++  java
  • c++的矩阵乘法加速trick

    最近读RNNLM的源代码,发现其实现矩阵乘法时使用了一个trick,这里描述一下这个trick。

    首先是正常版的矩阵乘法(其实是矩阵乘向量)

    void matrixXvector(float* destvect, float* srcmatrix, int srcmatrix_rownum, int srcmatrix_colnum, float* srcvect, int srcvect_size){
    	for(int row=0;row<srcmatrix_rownum;++row){
    		destvect[row]=0;
    		for(int col=0;col<srcmatrix_colnum;++col){
    			destvect[row]+=srcmatrix[row*srcmatrix_colnum+col]*srcvect[col];
    		}
    	}
    }
    
    

    就是最简单的for循环,逐行逐列遍历。

    接下来是RNNLM中实现的trick版本

    void matrixXvector2(float* destvect, float* srcmatrix, int srcmatrix_rownum, int srcmatrix_colnum, float* srcvect, int srcvect_size){
    	int row, col;
    	float val1, val2, val3, val4;
    	float val5, val6, val7, val8;
    	
    	for(row=0;row<srcmatrix_rownum/8;++row){
    		val1 = 0;
    		val2 = 0;
    		val3 = 0;
    		val4 = 0;
    		val5 = 0;
    		val6 = 0;
    		val7 = 0;
    		val8 = 0;
    		
    		for(col=0;col<srcmatrix_colnum;++col){
    			val1+=srcmatrix[(row*8+0)*srcmatrix_colnum+col]*srcvect[col];
    			val2+=srcmatrix[(row*8+1)*srcmatrix_colnum+col]*srcvect[col];
    			val3+=srcmatrix[(row*8+2)*srcmatrix_colnum+col]*srcvect[col];
    			val4+=srcmatrix[(row*8+3)*srcmatrix_colnum+col]*srcvect[col];
    			val5+=srcmatrix[(row*8+4)*srcmatrix_colnum+col]*srcvect[col];
    			val6+=srcmatrix[(row*8+5)*srcmatrix_colnum+col]*srcvect[col];
    			val7+=srcmatrix[(row*8+6)*srcmatrix_colnum+col]*srcvect[col];
    			val8+=srcmatrix[(row*8+7)*srcmatrix_colnum+col]*srcvect[col];
    		}
    		
    		destvect[row*8+0]+=val1;
    		destvect[row*8+1]+=val2;
    		destvect[row*8+2]+=val3;
    		destvect[row*8+3]+=val4;
    		destvect[row*8+4]+=val5;
    		destvect[row*8+5]+=val6;
    		destvect[row*8+6]+=val7;
    		destvect[row*8+7]+=val8;
    		
    	}
    	
    	for(row=row*8;row<srcmatrix_rownum;++row){
    		for(col=0;col<srcmatrix_colnum;++col){
    			destvect[row]+=srcmatrix[row*srcmatrix_colnum+col]*srcvect[col];	
    		}
    	}
    }
    
    

    对比普通版,trick版把遍历行的for循环分成了8份,同时进行列遍历。

    实际测试中,这个trick版比普通版快了接近2倍~这是编译器优化造成的么……?

  • 相关阅读:
    apache 虚拟主机配置(根据不同的域名映射到不同网站)
    Tortoise SVN 使用笔记
    Apache 根据不同的端口 映射不同的站点
    jquery 获取当前元素的索引值
    修改ThinkPHP的验证码类
    NetBeans无法使用编码GBK安全地打开该文件
    在win2003下apache2.2无法加载php5apache2_4.dll
    我看软件工程
    PHP函数参数传递(相对于C++的值传递和引用传递)
    Notepad++ 使用正则表达式查找替换字符串
  • 原文地址:https://www.cnblogs.com/plwang1990/p/4139357.html
Copyright © 2011-2022 走看看