zoukankan      html  css  js  c++  java
  • 用java写bp神经网络(三)

    孔子曰,吾日三省吾身。我们如果跟程序打交道,除了一日三省吾身外,还要三日一省吾代码。看代码是否可以更简洁,更易懂,更容易扩展,更通用,算法是否可以再优化,结构是否可以再往上抽象。代码在不断的重构过程中,更臻化境。佝偻者承蜩如是,大匠铸剑亦复如是,艺虽小,其道一也。所谓苟日新,再日新,日日新。

    本次对前两篇文章代码进行重构,主要重构函数接口体系,和权重矩阵的封装。

    简单函数

    所说函数,是数学概念上的函数。数学上的函数,一般有一自变量$x$(输入)和对应的值$y=f(x)$(输出)。其中$x$可以是个数字,一个向量,一个矩阵等等。我们用泛型定义如下:

    public interface Function<I,O> {
      O valueAt(I x);
    }
    

     I代表输入类型,O代表输出类型。

    有的函数是可微的,比如神经网络的激活函数。可微函数除了是一个函数,还可求出给定$x$处的导数,或者梯度。而且梯度类型与自变量类型一致。用泛型定义如下:

    public interface DifferentiableFunction<I,O> extends Function<I,O> {
      I derivativeAt(I x);
    }
    

    同时,考虑到某些函数,在求得值和导数时,共同用到了一些中间变量,或者后一个可以用到前一个的结果,我们定义了PreCaculate接口。当我们判定一个函数实现了PreCaculate接口时,我们首先调用它的PreCaculate接口,让它预先计算出一些有用的中间变量,然后再调用其valueAt和derivativeAt求得其具体的值,这样可以节省一些操作步骤。定义如下:

    public interface PreCaculate<I> {
    	void preCaculate(I x);
    }
    

     基于上面的定义,我们定义神经网络的激活函数的类型为:

    public interface ActivationFunction extends DifferentiableFunction<DoubleMatrix, DoubleMatrix>
    

     即我们激活函数是一个可微函数,输入为一个矩阵(netResult),输出为一个矩阵(finalResult)。

    带参函数

    有些函数,除了自变量外,还有一些其它的系数,或者参数,我们称为超参数。比如误差函数,目标值为参数,输出值为自变量。这类函数接口定义如下:

    public interface ParamFunction<I,O,P> {
    	O valueAt(I x,P param);
    }
    

     类似的,定义其微分接口如下:

    public interface DifferentiableParamFunction<I, O, P> extends ParamFunction<I, O, P> {
    	I derivativeAt(I x,P param);
    }
    

     我们的误差函数定义如下:

    public interface CostFunction extends DifferentiableParamFunction<DoubleMatrix,DoubleMatrix,DoubleMatrix>
    

     输入,输出,参数都为矩阵。

    组合矩阵

    在神经网络的概念中,每两层之间有一个权重矩阵,偏置矩阵,如果输入字向量也要调整,那么还有一个字典矩阵。这些所有的矩阵随着迭代过程不断更新,以期使误差函数达到最小。从广义上来讲,训练样本就是超参数,这些所有的矩阵为自变量,误差函数就是优化函数。那么实质上,在调整权重矩阵时,自变量即这一系列的矩阵可以展开拉长拼接成一个超长的向量而已,其内部的结构已无关紧要。在jare的源码中,是把这些权重矩阵的值存储在一个长的double[]中,计算完毕后,再从这个doulbe[]中还原出各矩阵的结构。在这里,我们定义了一个类CompactDoubleMatrix名为超矩阵来从更高一层封装这些矩阵变量,使其对外表现出好像就是一个矩阵。

    这个CompactDoubleMatrix的实现方式为,在内部维护一个DoubleMatrix的有序列表List<DoubleMatrix>,然后再执行加减乘除操作时,会批量的对列表中的所有矩阵执行。这样的封装,我们随后会发现将简化了我们大量代码。先把完整定义放上来。

    public class CompactDoubleMatrix {
    	List<DoubleMatrix> mats = new ArrayList<DoubleMatrix>();
    
    	@SafeVarargs
    	public CompactDoubleMatrix(List<DoubleMatrix>... matListArray) {
    		super();
    		this.append(matListArray);
    	}
    
    	public CompactDoubleMatrix(DoubleMatrix... matArray) {
    		super();
    		this.append(matArray);
    	}
    
    	public CompactDoubleMatrix() {
    		super();
    	}
    
    	public CompactDoubleMatrix addi(CompactDoubleMatrix other) {
    		this.assertSize(other);
    		for (int i = 0; i < this.length(); i++)
    			this.get(i).addi(other.get(i));
    		return this;
    	}
    
    	public void subi(CompactDoubleMatrix other) {
    		this.assertSize(other);
    		for (int i = 0; i < this.length(); i++)
    			this.get(i).subi(other.get(i));
    	}
    
    	public CompactDoubleMatrix add(CompactDoubleMatrix other) {
    		this.assertSize(other);
    		CompactDoubleMatrix result = new CompactDoubleMatrix();
    		for (int i = 0; i < this.length(); i++) {
    			result.append(this.get(i).add(other.get(i)));
    		}
    		return result;
    	}
    
    	public CompactDoubleMatrix sub(CompactDoubleMatrix other) {
    		this.assertSize(other);
    		CompactDoubleMatrix result = new CompactDoubleMatrix();
    		for (int i = 0; i < this.length(); i++) {
    			result.append(this.get(i).sub(other.get(i)));
    		}
    		return result;
    	}
    
    	public CompactDoubleMatrix mul(CompactDoubleMatrix other) {
    		this.assertSize(other);
    		CompactDoubleMatrix result = new CompactDoubleMatrix();
    		for (int i = 0; i < this.length(); i++) {
    			result.append(this.get(i).mul(other.get(i)));
    		}
    		return result;
    	}
    
    	public CompactDoubleMatrix muli(double d) {
    
    		for (int i = 0; i < this.length(); i++) {
    			this.get(i).muli(d);
    		}
    		return this;
    	}
    
    	public CompactDoubleMatrix mul(double d) {
    		CompactDoubleMatrix result = new CompactDoubleMatrix();
    		for (int i = 0; i < this.length(); i++) {
    			result.append(this.get(i).mul(d));
    		}
    		return result;
    	}
    
    	public CompactDoubleMatrix dup() {
    		CompactDoubleMatrix result = new CompactDoubleMatrix();
    		for (int i = 0; i < this.length(); i++) {
    			result.append(this.get(i).dup());
    		}
    		return result;
    	}
    
    	public double dot(CompactDoubleMatrix other) {
    		double sum = 0;
    		for (int i = 0; i < this.length(); i++) {
    			sum += this.get(i).dot(other.get(i));
    		}
    		return sum;
    	}
    
    	public double norm() {
    		double sum = 0;
    		for (int i = 0; i < this.length(); i++) {
    			double subNorm = this.get(i).norm2();
    			sum += subNorm * subNorm;
    		}
    		return Math.sqrt(sum);
    	}
    
    	public void assertSize(CompactDoubleMatrix other) {
    		assert (other != null && this.length() == other.length());
    		for (int i = 0; i < this.length(); i++) {
    			assert (this.get(i).sameSize(other.get(i)));
    		}
    	}
    
    	@SuppressWarnings("unchecked")
    	public void append(List<DoubleMatrix>... matListArray) {
    		for (List<DoubleMatrix> list : matListArray) {
    			this.mats.addAll(list);
    		}
    	}
    
    	public void append(DoubleMatrix... matArray) {
    		for (DoubleMatrix mat : matArray)
    			this.mats.add(mat);
    	}
    
    	public int length() {
    		return mats.size();
    	}
    
    	public DoubleMatrix get(int index) {
    		return this.mats.get(index);
    	}
    
    	public DoubleMatrix getLast() {
    		return this.mats.get(this.length() - 1);
    	}
    }
    

     以上介绍了对各抽象概念的封装,下章介绍使用这些封装如何简化我们的代码。

  • 相关阅读:
    父子进程 signal 出现 Interrupted system call 问题
    一个测试文章
    《淘宝客户端 for Android》项目实战 html webkit android css3
    Django 中的 ForeignKey ContentType GenericForeignKey 对应的数据库结构
    coreseek 出现段错误和Unigram dictionary load Error 新情况(Gentoo)
    一个 PAM dbus 例子
    漫画统计学 T分数
    解决 paramiko 安装问题 Unable to find vcvarsall.bat
    20141202
    js
  • 原文地址:https://www.cnblogs.com/wuseguang/p/4140218.html
Copyright © 2011-2022 走看看