Writing a BP Neural Network in Java (Part 2)

    Continuing from the previous post.

    With Net and Propagation in place, we are ready to train. The trainer's job is to decide how a large pool of samples gets split into mini-batches for training and how the per-batch results are merged back into a complete result (batch vs. incremental); when to hand the result to the learner so it can update the network's weights and state; and when to declare training finished.

    So what do these two "teachers" look like, and how do they get the job done?

    public interface Trainer {
        public void train(Net net,DataProvider provider);
    }
    
    public interface Learner {
        public void learn(Net net,TrainResult trainResult);
    }
    

     A Trainer, given data, trains the specified network; a Learner, given a training result, adjusts the weights of the specified network.
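
    The DataProvider and TrainResult types referenced here are not defined in this post. Judging from how they are used below (provider.getInput(), provider.getTarget(), new TrainResult(null, backwardResult)), minimal sketches might look like the following; the exact shapes are assumptions, not the original definitions (ForwardResult and BackwardResult come from the Propagation code in the previous post):

    import org.jblas.DoubleMatrix;
    
    // Supplies training data as jblas matrices, one sample per column.
    public interface DataProvider {
    	DoubleMatrix getInput();  // inputs: inputNodes x sampleCount
    	DoubleMatrix getTarget(); // targets: outputNodes x sampleCount
    }
    
    // Bundles the results of one training pass; only the backward half is used here.
    public class TrainResult {
    	public final ForwardResult forwardResult;   // may be null in this post
    	public final BackwardResult backwardResult;
    
    	public TrainResult(ForwardResult forwardResult, BackwardResult backwardResult) {
    		this.forwardResult = forwardResult;
    		this.backwardResult = backwardResult;
    	}
    }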

    Simple implementations of the two interfaces are given below.

    Trainer

    CommonTrainer implements simple batch training and stops after a given number of epochs. The code follows.

    public class CommonTrainer implements Trainer {
    	int epochs;
    	Learner learner;
    	List<Double> costs = new ArrayList<>();
    	List<Double> accuracies = new ArrayList<>();
    	int batchSize = 1;
    
    	public CommonTrainer(int epochs, Learner learner) {
    		super();
    		this.epochs = epochs;
    		this.learner = learner == null ? new MomentAdaptLearner() : learner;
    	}
    
    	public CommonTrainer(int epochs, Learner learner, int batchSize) {
    		this(epochs, learner);
    		this.batchSize = batchSize;
    	}
    
    	public void trainOne(final Net net, DataProvider provider) {
    		final Propagation propagation = new Propagation(net);
    		DoubleMatrix input = provider.getInput();
    		DoubleMatrix target = provider.getTarget();
    		final int allLen = target.columns;
    		final int[] nodesNum = net.getNodesNum();
    		final int layersNum = net.getLayersNum();
    
    		List<DoubleMatrix> inputBatches = this.getBatches(input);
    		final List<DoubleMatrix> targetBatches = this.getBatches(target);
    
    		final List<Integer> batchLen = MatrixUtil.getEndPosition(targetBatches);
    
    		final BackwardResult backwardResult = new BackwardResult(net, allLen);
    
    		// Train the batches in parallel
    		Parallel.For(inputBatches, new Parallel.Operation<DoubleMatrix>() {
    			@Override
    			public void perform(int index, DoubleMatrix subInput) {
    				ForwardResult subResult = propagation.forward(subInput);
    				DoubleMatrix subTarget = targetBatches.get(index);
    				BackwardResult backResult = propagation.backward(subTarget,
    						subResult);
    
    				DoubleMatrix cost = backwardResult.cost;
    				DoubleMatrix accuracy = backwardResult.accuracy;
    				DoubleMatrix inputDeltas = backwardResult.getInputDelta();
    
    				int start = index == 0 ? 0 : batchLen.get(index - 1);
    				int end = batchLen.get(index) - 1;
    				int[] cIndexs = ArraysHelper.makeArray(start, end);
    
    				cost.put(cIndexs, backResult.cost);
    
    				if (accuracy != null) {
    					accuracy.put(cIndexs, backResult.accuracy);
    				}
    
    				inputDeltas.put(ArraysHelper.makeArray(0, nodesNum[0] - 1),
    						  cIndexs, backResult.getInputDelta());
    
    				for (int i = 0; i < layersNum; i++) {
    					DoubleMatrix gradients = backwardResult.gradients.get(i);
    					DoubleMatrix biasGradients = backwardResult.biasGradients
    							.get(i);
    
      					DoubleMatrix subGradients = backResult.gradients.get(i)
    							.muli(backResult.cost.columns);
    					DoubleMatrix subBiasGradients = backResult.biasGradients
    							.get(i).muli(backResult.cost.columns);
    					gradients.addi(subGradients);
    					biasGradients.addi(subBiasGradients);
    				}
    			}
    		});
    		// Average the accumulated gradients over the full sample count
    		for (DoubleMatrix gradient : backwardResult.gradients) {
    			gradient.divi(allLen);
    		}
    		for (DoubleMatrix biasGradient : backwardResult.biasGradients) {
    			biasGradient.divi(allLen);
    		}
    		
    		// this.mergeBackwardResult(backResults, net, input.columns);
    		TrainResult trainResult = new TrainResult(null, backwardResult);
    
    		learner.learn(net, trainResult);
    
    		Double cost = backwardResult.getMeanCost();
    		Double accuracy = backwardResult.getMeanAccuracy();
    		if (cost != null)
    			costs.add(cost);
    		if (accuracy != null)
    			accuracies.add(accuracy);
    
    		System.out.println("cost: " + cost);
    		System.out.println("accuracy: " + accuracy);
    	}
    
    	@Override
    	public void train(Net net, DataProvider provider) {
    		for (int i = 0; i < this.epochs; i++) {
    			this.trainOne(net, provider);
    		}
    	}
    }
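
    CommonTrainer also leans on a few helpers whose definitions are not shown: getBatches, MatrixUtil.getEndPosition, and ArraysHelper.makeArray. Based purely on how they are called above, plausible sketches follow; treat them as reconstructions under assumptions, not the original code.

    // getBatches: split the samples (one per column) into batches of at most batchSize columns.
    List<DoubleMatrix> getBatches(DoubleMatrix data) {
    	List<DoubleMatrix> batches = new ArrayList<>();
    	for (int start = 0; start < data.columns; start += batchSize) {
    		int end = Math.min(start + batchSize, data.columns) - 1;
    		batches.add(data.getColumns(ArraysHelper.makeArray(start, end)));
    	}
    	return batches;
    }
    
    // MatrixUtil.getEndPosition: cumulative (exclusive) end index of each batch
    // within the full sample set, e.g. batch widths [3, 3, 2] -> [3, 6, 8].
    public static List<Integer> getEndPosition(List<DoubleMatrix> batches) {
    	List<Integer> ends = new ArrayList<>();
    	int sum = 0;
    	for (DoubleMatrix batch : batches) {
    		sum += batch.columns;
    		ends.add(sum);
    	}
    	return ends;
    }
    
    // ArraysHelper.makeArray: inclusive integer range, makeArray(2, 4) -> {2, 3, 4}.
    public static int[] makeArray(int start, int end) {
    	int[] range = new int[end - start + 1];
    	for (int i = 0; i < range.length; i++) {
    		range[i] = start + i;
    	}
    	return range;
    }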
    

    Learner

    The Learner is the concrete update algorithm: once the gradients have been computed, it adjusts the network weights. The choice of update algorithm directly determines how quickly the network converges. The implementation here uses a simple momentum plus adaptive-learning-rate scheme.

    The update rule is:

    $$W(t+1)=W(t)+\Delta W(t)$$

    $$\Delta W(t)=rate(t)\,(1-moment(t))\,G(t)+moment(t)\,\Delta W(t-1)$$

    $$rate(t+1)=\begin{cases} rate(t)\times 1.05 & \mbox{if } cost(t)<cost(t-1)\\ rate(t)\times 0.7 & \mbox{else if } cost(t)<cost(t-1)\times 1.04\\ 0.01 & \mbox{else} \end{cases}$$

    $$moment(t+1)=\begin{cases} 0.9 & \mbox{if } cost(t)<cost(t-1)\\ moment(t)\times 0.7 & \mbox{else if } cost(t)<cost(t-1)\times 1.04\\ 1-0.9 & \mbox{else} \end{cases}$$

    Here $G(t)$ is the gradient of the cost; in the code below the combined step is applied via weight.subi(...), so the minus sign of gradient descent is folded into how $\Delta W(t)$ is applied.

    Sample code:

    public class MomentAdaptLearner implements Learner {
    
    	Net net;
    	double moment = 0.9;
    	double lmd = 1.05;
    	double preCost = 0;
    	double eta = 0.01;
    	double currentEta=eta;
    	double currentMoment=moment;
    	TrainResult preTrainResult;
    	
    	// No-arg constructor: CommonTrainer falls back to new MomentAdaptLearner(),
    	// which uses the default field values above.
    	public MomentAdaptLearner() {
    		super();
    	}
    
    	public MomentAdaptLearner(double moment, double eta) {
    		super();
    		this.moment = moment;
    		this.eta = eta;
    		this.currentEta = eta;
    		this.currentMoment = moment;
    	}
    
    	@Override
    	public void learn(Net net, TrainResult trainResult) {
    		if (this.net == null)
    			init(net);
    		
    		BackwardResult backwardResult = trainResult.backwardResult;
    		BackwardResult preBackwardResult = preTrainResult.backwardResult;
    		double cost=backwardResult.getMeanCost();
    		this.modifyParameter(cost);
    		System.out.println("current eta:"+this.currentEta);
    		System.out.println("current moment:"+this.currentMoment);
    		for (int j = 0; j < net.getLayersNum(); j++) {
    			DoubleMatrix weight = net.getWeights().get(j);
    			DoubleMatrix gradient = backwardResult.gradients.get(j);
    
    			gradient = gradient.muli(currentEta * (1 - this.currentMoment)).addi(
    					preBackwardResult.gradients.get(j).muli(this.currentMoment));
    			preBackwardResult.gradients.set(j, gradient);
    
    			weight.subi(gradient);
    			
    			DoubleMatrix b = net.getBs().get(j);
    			DoubleMatrix bgradient = backwardResult.biasGradients.get(j);
    
    			bgradient = bgradient.muli(currentEta * (1 - this.currentMoment)).addi(
    					preBackwardResult.biasGradients.get(j).muli(this.currentMoment));
    			preBackwardResult.biasGradients.set(j, bgradient);
    
    			b.subi(bgradient);
    		}
    
    	}
    
    	// Adapt the learning rate and momentum according to whether the cost improved.
    	public void modifyParameter(double cost) {
    		if (cost < this.preCost) {
    			this.currentEta *= 1.05;
    			this.currentMoment = moment;
    		} else if (cost < 1.04 * this.preCost) {
    			this.currentEta *= 0.7;
    			this.currentMoment *= 0.7;
    		} else {
    			this.currentEta = eta;
    			this.currentMoment = 1 - moment;
    		}
    		this.preCost = cost;
    	}

    	public void init(Net net) {
    		this.net =  net;
    		BackwardResult bResult = new BackwardResult();
    
    		for (DoubleMatrix weight : net.getWeights()) {
    			bResult.gradients.add(DoubleMatrix.zeros(weight.rows,
    					weight.columns));
    		}
    		for (DoubleMatrix b : net.getBs()) {
    			bResult.biasGradients.add(DoubleMatrix.zeros(b.rows, b.columns));
    		}
    		preTrainResult=new TrainResult(null,bResult);
    	}
    
    
    }
    

    With that, a simple neural network is fully implemented, from construction through training.
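
    To make the flow concrete, here is a rough usage sketch wiring everything together on XOR data. The Net constructor is hypothetical (use whatever builder part 1 actually provides), and the anonymous DataProvider matches the minimal interface sketched earlier:

    import org.jblas.DoubleMatrix;
    
    public class XorDemo {
    	public static void main(String[] args) {
    		// XOR samples, one per column: inputs are 2 x 4, targets are 1 x 4.
    		final DoubleMatrix input = new DoubleMatrix(new double[][] {
    				{ 0, 0, 1, 1 },
    				{ 0, 1, 0, 1 } });
    		final DoubleMatrix target = new DoubleMatrix(new double[][] {
    				{ 0, 1, 1, 0 } });
    
    		DataProvider provider = new DataProvider() {
    			public DoubleMatrix getInput() { return input; }
    			public DoubleMatrix getTarget() { return target; }
    		};
    
    		// Hypothetical constructor: build a 2-4-1 network however part 1 does it.
    		Net net = new Net(new int[] { 2, 4, 1 });
    
    		// 1000 epochs, momentum 0.9, initial learning rate 0.01, mini-batches of 2.
    		Trainer trainer = new CommonTrainer(1000, new MomentAdaptLearner(0.9, 0.01), 2);
    		trainer.train(net, provider);
    	}
    }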

    Next step: improve the convergence rate with the Levenberg-Marquardt learning algorithm.
