zoukankan      html  css  js  c++  java
  • 用java写bp神经网络(三)

    孔子曰,吾日三省吾身。我们如果跟程序打交道,除了一日三省吾身外,还要三日一省吾代码。看代码是否可以更简洁,更易懂,更容易扩展,更通用,算法是否可以再优化,结构是否可以再往上抽象。代码在不断的重构过程中,更臻化境。佝偻者承蜩如是,大匠铸剑亦复如是,艺虽小,其道一也。所谓苟日新,再日新,日日新。

    本次对前两篇文章代码进行重构,主要重构函数接口体系,和权重矩阵的封装。

    简单函数

    所说函数,是数学概念上的函数。数学上的函数,一般有一自变量$x$(输入)和对应的值$y=f(x)$(输出)。其中$x$可以是个数字,一个向量,一个矩阵等等。我们用泛型定义如下:

    public interface Function<I,O> {
      O valueAt(I x);
    }
    

     I代表输入类型,O代表输出类型。

    有的函数是可微的,比如神经网络的激活函数。可微函数除了是一个函数,还可求出给定$x$处的导数,或者梯度。而且梯度类型与自变量类型一致。用泛型定义如下:

    public interface DifferentiableFunction<I,O> extends Function<I,O> {
      I derivativeAt(I x);
    }
    

    同时,考虑到某些函数,在求得值和导数时,共同用到了一些中间变量,或者后一个可以用到前一个的结果,我们定义了PreCaculate接口。当我们判定一个函数实现了PreCaculate接口时,我们首先调用它的PreCaculate接口,让它预先计算出一些有用的中间变量,然后再调用其valueAt和derivativeAt求得其具体的值,这样可以节省一些操作步骤。定义如下:

    public interface PreCaculate<I> {
    	void preCaculate(I x);
    }
    

     基于上面的定义,我们定义神经网络的激活函数的类型为:

    public interface ActivationFunction extends DifferentiableFunction<DoubleMatrix, DoubleMatrix>
    

     即我们激活函数是一个可微函数,输入为一个矩阵(netResult),输出为一个矩阵(finalResult)。

    带参函数

    有些函数,除了自变量外,还有一些其它的系数,或者参数,我们称为超参数。比如误差函数,目标值为参数,输出值为自变量。这类函数接口定义如下:

    public interface ParamFunction<I,O,P> {
    	O valueAt(I x,P param);
    }
    

     类似的,定义其微分接口如下:

    public interface DifferentiableParamFunction<I, O, P> extends ParamFunction<I, O, P> {
    	I derivativeAt(I x,P param);
    }
    

     我们的误差函数定义如下:

    public interface CostFunction extends DifferentiableParamFunction<DoubleMatrix,DoubleMatrix,DoubleMatrix>
    

     输入,输出,参数都为矩阵。

    组合矩阵

    在神经网络的概念中,每两层之间有一个权重矩阵,偏置矩阵,如果输入字向量也要调整,那么还有一个字典矩阵。这些所有的矩阵随着迭代过程不断更新,以期使误差函数达到最小。从广义上来讲,训练样本就是超参数,这些所有的矩阵为自变量,误差函数就是优化函数。那么实质上,在调整权重矩阵时,自变量即这一系列的矩阵可以展开拉长拼接成一个超长的向量而已,其内部的结构已无关紧要。在jare的源码中,是把这些权重矩阵的值存储在一个长的double[]中,计算完毕后,再从这个doulbe[]中还原出各矩阵的结构。在这里,我们定义了一个类CompactDoubleMatrix名为超矩阵来从更高一层封装这些矩阵变量,使其对外表现出好像就是一个矩阵。

    这个CompactDoubleMatrix的实现方式为,在内部维护一个DoubleMatrix的有序列表List<DoubleMatrix>,然后再执行加减乘除操作时,会批量的对列表中的所有矩阵执行。这样的封装,我们随后会发现将简化了我们大量代码。先把完整定义放上来。

    public class CompactDoubleMatrix {
    	List<DoubleMatrix> mats = new ArrayList<DoubleMatrix>();
    
    	@SafeVarargs
    	public CompactDoubleMatrix(List<DoubleMatrix>... matListArray) {
    		super();
    		this.append(matListArray);
    	}
    
    	public CompactDoubleMatrix(DoubleMatrix... matArray) {
    		super();
    		this.append(matArray);
    	}
    
    	public CompactDoubleMatrix() {
    		super();
    	}
    
    	public CompactDoubleMatrix addi(CompactDoubleMatrix other) {
    		this.assertSize(other);
    		for (int i = 0; i < this.length(); i++)
    			this.get(i).addi(other.get(i));
    		return this;
    	}
    
    	public void subi(CompactDoubleMatrix other) {
    		this.assertSize(other);
    		for (int i = 0; i < this.length(); i++)
    			this.get(i).subi(other.get(i));
    	}
    
    	public CompactDoubleMatrix add(CompactDoubleMatrix other) {
    		this.assertSize(other);
    		CompactDoubleMatrix result = new CompactDoubleMatrix();
    		for (int i = 0; i < this.length(); i++) {
    			result.append(this.get(i).add(other.get(i)));
    		}
    		return result;
    	}
    
    	public CompactDoubleMatrix sub(CompactDoubleMatrix other) {
    		this.assertSize(other);
    		CompactDoubleMatrix result = new CompactDoubleMatrix();
    		for (int i = 0; i < this.length(); i++) {
    			result.append(this.get(i).sub(other.get(i)));
    		}
    		return result;
    	}
    
    	public CompactDoubleMatrix mul(CompactDoubleMatrix other) {
    		this.assertSize(other);
    		CompactDoubleMatrix result = new CompactDoubleMatrix();
    		for (int i = 0; i < this.length(); i++) {
    			result.append(this.get(i).mul(other.get(i)));
    		}
    		return result;
    	}
    
    	public CompactDoubleMatrix muli(double d) {
    
    		for (int i = 0; i < this.length(); i++) {
    			this.get(i).muli(d);
    		}
    		return this;
    	}
    
    	public CompactDoubleMatrix mul(double d) {
    		CompactDoubleMatrix result = new CompactDoubleMatrix();
    		for (int i = 0; i < this.length(); i++) {
    			result.append(this.get(i).mul(d));
    		}
    		return result;
    	}
    
    	public CompactDoubleMatrix dup() {
    		CompactDoubleMatrix result = new CompactDoubleMatrix();
    		for (int i = 0; i < this.length(); i++) {
    			result.append(this.get(i).dup());
    		}
    		return result;
    	}
    
    	public double dot(CompactDoubleMatrix other) {
    		double sum = 0;
    		for (int i = 0; i < this.length(); i++) {
    			sum += this.get(i).dot(other.get(i));
    		}
    		return sum;
    	}
    
    	public double norm() {
    		double sum = 0;
    		for (int i = 0; i < this.length(); i++) {
    			double subNorm = this.get(i).norm2();
    			sum += subNorm * subNorm;
    		}
    		return Math.sqrt(sum);
    	}
    
    	public void assertSize(CompactDoubleMatrix other) {
    		assert (other != null && this.length() == other.length());
    		for (int i = 0; i < this.length(); i++) {
    			assert (this.get(i).sameSize(other.get(i)));
    		}
    	}
    
    	@SuppressWarnings("unchecked")
    	public void append(List<DoubleMatrix>... matListArray) {
    		for (List<DoubleMatrix> list : matListArray) {
    			this.mats.addAll(list);
    		}
    	}
    
    	public void append(DoubleMatrix... matArray) {
    		for (DoubleMatrix mat : matArray)
    			this.mats.add(mat);
    	}
    
    	public int length() {
    		return mats.size();
    	}
    
    	public DoubleMatrix get(int index) {
    		return this.mats.get(index);
    	}
    
    	public DoubleMatrix getLast() {
    		return this.mats.get(this.length() - 1);
    	}
    }
    

     以上介绍了对各抽象概念的封装,下章介绍使用这些封装如何简化我们的代码。

  • 相关阅读:
    VirtualBox Linux服务vboxservicetemplate
    oracle 11g常用命令
    haproxy dataplaneapi
    使用jproflier 分析dremio
    cube.js 支持oceanbase 的mysql driver
    fastdfs 集群异常修复实践
    使用jHiccup 分析java 应用性能
    dremio mysql arp 扩展
    cube.js graphql 支持
    apache kyuubi 参考架构集成
  • 原文地址:https://www.cnblogs.com/wuseguang/p/4140218.html
Copyright © 2011-2022 走看看