  • Writing a BP Neural Network in Java (Part 2)

    Continuing from the previous post.

    With Net and Propagation in place, we can start training. The trainer's job is to decide how to split a large set of samples into mini-batches and how to merge the mini-batch results back into a complete result (batch/incremental training), when to hand the training results to the learner so it can update the network's weights and state, and when to declare training finished.

    So what do these two "teachers" look like, and how do they do their jobs?

    public interface Trainer {
        public void train(Net net, DataProvider provider);
    }
    
    public interface Learner {
        public void learn(Net net, TrainResult trainResult);
    }
    

    A Trainer takes a data set and trains the given network; a Learner takes a training result and adjusts the given network's weights accordingly.
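
    The DataProvider referenced by the Trainer interface is defined elsewhere (presumably in part one) and only exposes the training data. Purely as an illustration, a minimal in-memory implementation might look like the sketch below; the getInput()/getTarget() signatures are inferred from how CommonTrainer uses them later, and the class name SimpleDataProvider is hypothetical.

    import org.jblas.DoubleMatrix;
    
    // Minimal sketch of a DataProvider: it just hands back the matrices it was
    // constructed with, one sample per column, matching how CommonTrainer splits
    // the data column-wise below. (Assumed interface shape, not the original code.)
    public class SimpleDataProvider implements DataProvider {
    	private final DoubleMatrix input;
    	private final DoubleMatrix target;
    
    	public SimpleDataProvider(DoubleMatrix input, DoubleMatrix target) {
    		this.input = input;
    		this.target = target;
    	}
    
    	@Override
    	public DoubleMatrix getInput() {
    		return input;
    	}
    
    	@Override
    	public DoubleMatrix getTarget() {
    		return target;
    	}
    }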

    Simple implementations of these two interfaces follow.

    Trainer

    This Trainer implements simple batch training and stops after a fixed number of epochs. Sample code:

    public class CommonTrainer implements Trainer {
    	int ecophs;
    	Learner learner;
    	List<Double> costs = new ArrayList<>();
    	List<Double> accuracys = new ArrayList<>();
    	int batchSize = 1;
    
    	public CommonTrainer(int ecophs, Learner learner) {
    		super();
    		this.ecophs = ecophs;
    		this.learner = learner == null ? new MomentAdaptLearner() : learner;
    	}
    
    	public CommonTrainer(int ecophs, Learner learner, int batchSize) {
    		this(ecophs, learner);
    		this.batchSize = batchSize;
    	}
    
    	public void trainOne(final Net net, DataProvider provider) {
    		final Propagation propagation = new Propagation(net);
    		DoubleMatrix input = provider.getInput();
    		DoubleMatrix target = provider.getTarget();
    		final int allLen = target.columns;
    		final int[] nodesNum = net.getNodesNum();
    		final int layersNum = net.getLayersNum();
    
    		List<DoubleMatrix> inputBatches = this.getBatches(input);
    		final List<DoubleMatrix> targetBatches = this.getBatches(target);
    
    		final List<Integer> batchLen = MatrixUtil.getEndPosition(targetBatches);
    
    		final BackwardResult backwardResult = new BackwardResult(net, allLen);
    
    		// train the mini-batches in parallel
    		Parallel.For(inputBatches, new Parallel.Operation<DoubleMatrix>() {
    			@Override
    			public void perform(int index, DoubleMatrix subInput) {
    				ForwardResult subResult = propagation.forward(subInput);
    				DoubleMatrix subTarget = targetBatches.get(index);
    				BackwardResult backResult = propagation.backward(subTarget,
    						subResult);
    
    				DoubleMatrix cost = backwardResult.cost;
    				DoubleMatrix accuracy = backwardResult.accuracy;
    				DoubleMatrix inputDeltas = backwardResult.getInputDelta();
    
    				int start = index == 0 ? 0 : batchLen.get(index - 1);
    				int end = batchLen.get(index) - 1;
    				int[] cIndexs = ArraysHelper.makeArray(start, end);
    
    				cost.put(cIndexs, backResult.cost);
    
    				if (accuracy != null) {
    					accuracy.put(cIndexs, backResult.accuracy);
    				}
    
    				inputDeltas.put(ArraysHelper.makeArray(0, nodesNum[0] - 1),
    						  cIndexs, backResult.getInputDelta());
    
    				for (int i = 0; i < layersNum; i++) {
    					DoubleMatrix gradients = backwardResult.gradients.get(i);
    					DoubleMatrix biasGradients = backwardResult.biasGradients
    							.get(i);
    
      					DoubleMatrix subGradients = backResult.gradients.get(i)
    							.muli(backResult.cost.columns);
    					DoubleMatrix subBiasGradients = backResult.biasGradients
    							.get(i).muli(backResult.cost.columns);
    					gradients.addi(subGradients);
    					biasGradients.addi(subBiasGradients);
    				}
    			}
    		});
    		// average the accumulated gradients over all samples
    		for(DoubleMatrix gradient:backwardResult.gradients){
    			gradient.divi(allLen);
    		}
    		for(DoubleMatrix gradient:backwardResult.biasGradients){
    			gradient.divi(allLen);
    		}
    		
    		// this.mergeBackwardResult(backResults, net, input.columns);
    		TrainResult trainResult = new TrainResult(null, backwardResult);
    
    		learner.learn(net, trainResult);
    
    		Double cost = backwardResult.getMeanCost();
    		Double accuracy = backwardResult.getMeanAccuracy();
    		if (cost != null)
    			costs.add(cost);
    		if (accuracy != null)
    			accuracys.add(accuracy);
    
      		System.out.println(cost);
    		System.out.println(accuracy);
    	}
    
    	@Override
    	public void train(Net net, DataProvider provider) {
    		for (int i = 0; i < this.ecophs; i++) {
    			this.trainOne(net, provider);
    		}
    
    	}
    }
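
    The trainOne method above also relies on a few helpers that are not listed in this post: getBatches(...), MatrixUtil.getEndPosition(...) and ArraysHelper.makeArray(...). Their implementations are not shown here, so the following is only a hedged sketch whose behavior is inferred from how trainOne calls them; the real code may differ.

    // Sketch of the helpers used by trainOne (inferred from usage, not the original code).
    // Requires: import java.util.ArrayList; import java.util.List; import org.jblas.DoubleMatrix;
    
    // CommonTrainer.getBatches: split a matrix column-wise (one sample per column)
    // into consecutive batches of at most batchSize samples.
    private List<DoubleMatrix> getBatches(DoubleMatrix data) {
    	List<DoubleMatrix> batches = new ArrayList<>();
    	for (int start = 0; start < data.columns; start += batchSize) {
    		int end = Math.min(start + batchSize, data.columns);
    		batches.add(data.getColumns(ArraysHelper.makeArray(start, end - 1)));
    	}
    	return batches;
    }
    
    // MatrixUtil.getEndPosition: cumulative (exclusive) end column of each batch,
    // so batch i covers columns [endPositions(i-1), endPositions(i)).
    public static List<Integer> getEndPosition(List<DoubleMatrix> batches) {
    	List<Integer> ends = new ArrayList<>();
    	int sum = 0;
    	for (DoubleMatrix batch : batches) {
    		sum += batch.columns;
    		ends.add(sum);
    	}
    	return ends;
    }
    
    // ArraysHelper.makeArray: the inclusive integer range [start, end].
    public static int[] makeArray(int start, int end) {
    	int[] range = new int[end - start + 1];
    	for (int i = 0; i < range.length; i++) {
    		range[i] = start + i;
    	}
    	return range;
    }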
    

    Learner

    The Learner is the concrete update algorithm: once the gradients have been computed, it is responsible for adjusting the network weights. The choice of update algorithm directly affects how quickly the network converges. The implementation in this post uses a simple momentum plus adaptive-learning-rate scheme.

    Its update formulas are:

    $$W(t+1)=W(t)+\Delta W(t)$$

    $$\Delta W(t)=rate(t)(1-moment(t))G(t)+moment(t)\Delta W(t-1)$$

    $$rate(t+1)=\begin{cases} rate(t)\times 1.05 & \mbox{if } cost(t)<cost(t-1)\\ rate(t)\times 0.7 & \mbox{else if } cost(t)<cost(t-1)\times 1.04\\ 0.01 & \mbox{else} \end{cases}$$

    $$moment(t+1)=\begin{cases} 0.9 & \mbox{if } cost(t)<cost(t-1)\\ moment(t)\times 0.7 & \mbox{else if } cost(t)<cost(t-1)\times 1.04\\ 1-0.9 & \mbox{else} \end{cases}$$

    Sample code:

    public class MomentAdaptLearner implements Learner {
    
    	Net net;
    	double moment = 0.9;
    	double lmd = 1.05;
    	double preCost = 0;
    	double eta = 0.01;
    	double currentEta=eta;
    	double currentMoment=moment;
    	TrainResult preTrainResult;
    	
    	// No-argument constructor so that the "new MomentAdaptLearner()" fallback in
    	// CommonTrainer compiles; it keeps the default moment (0.9) and eta (0.01).
    	public MomentAdaptLearner() {
    		super();
    	}
    
    	public MomentAdaptLearner(double moment, double eta) {
    		super();
    		this.moment = moment;
    		this.eta = eta;
    		this.currentEta = eta;
    		this.currentMoment = moment;
    	}
    
    	@Override
    	public void learn(Net net, TrainResult trainResult) {
    		if (this.net == null)
    			init(net);
    		
    		BackwardResult backwardResult = trainResult.backwardResult;
    		BackwardResult preBackwardResult = preTrainResult.backwardResult;
    		double cost=backwardResult.getMeanCost();
    		this.modifyParameter(cost);
    		System.out.println("current eta:"+this.currentEta);
    		System.out.println("current moment:"+this.currentMoment);
    		for (int j = 0; j < net.getLayersNum(); j++) {
    			DoubleMatrix weight = net.getWeights().get(j);
    			DoubleMatrix gradient = backwardResult.gradients.get(j);
    
    			gradient = gradient.muli(currentEta * (1 - this.currentMoment)).addi(
    					preBackwardResult.gradients.get(j).muli(this.currentMoment));
    			preBackwardResult.gradients.set(j, gradient);
    
    			weight.subi(gradient);
    			
    			DoubleMatrix b = net.getBs().get(j);
    			DoubleMatrix bgradient = backwardResult.biasGradients.get(j);
    
    			bgradient = bgradient.muli(currentEta * (1 - this.currentMoment)).addi(
    					preBackwardResult.biasGradients.get(j).muli(this.currentMoment));
    			preBackwardResult.biasGradients.set(j, bgradient);
    
    			b.subi(bgradient);
    		}
    
    	}
    
    	// Adjust the learning rate and momentum according to how the cost moved,
    	// mirroring the rate(t+1) and moment(t+1) cases above.
    	public void modifyParameter(double cost) {
    		if (cost < this.preCost) {
    			// cost decreased: speed up and restore full momentum
    			this.currentEta *= 1.05;
    			this.currentMoment = moment;
    		} else if (cost < 1.04 * this.preCost) {
    			// cost grew by less than 4%: damp both
    			this.currentEta *= 0.7;
    			this.currentMoment *= 0.7;
    		} else {
    			// cost clearly increased: fall back to the initial rate and a small momentum
    			this.currentEta = eta;
    			this.currentMoment = 1 - moment;
    		}
    		this.preCost = cost;
    	}
    	public void init(Net net) {
    		this.net =  net;
    		BackwardResult bResult = new BackwardResult();
    
    		for (DoubleMatrix weight : net.getWeights()) {
    			bResult.gradients.add(DoubleMatrix.zeros(weight.rows,
    					weight.columns));
    		}
    		for (DoubleMatrix b : net.getBs()) {
    			bResult.biasGradients.add(DoubleMatrix.zeros(b.rows, b.columns));
    		}
    		preTrainResult=new TrainResult(null,bResult);
    	}
    
    
    }
    

    At this point, a simple neural network has been implemented end to end, from construction through training.
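
    To make the wiring concrete, here is a hedged usage sketch. The Net construction API comes from part one and is not shown here, so buildNet() is only a hypothetical placeholder, and SimpleDataProvider is the illustrative provider sketched earlier.

    import org.jblas.DoubleMatrix;
    
    public class TrainingExample {
    	public static void main(String[] args) {
    		// XOR-style toy data: one sample per column.
    		DoubleMatrix input = new DoubleMatrix(new double[][] {
    				{ 0, 0, 1, 1 },
    				{ 0, 1, 0, 1 } });
    		DoubleMatrix target = new DoubleMatrix(new double[][] {
    				{ 0, 1, 1, 0 } });
    
    		DataProvider provider = new SimpleDataProvider(input, target);
    		Net net = buildNet(); // build the network with part one's API
    
    		// 1000 epochs, momentum 0.9, learning rate 0.01, mini-batches of 2 samples.
    		Learner learner = new MomentAdaptLearner(0.9, 0.01);
    		Trainer trainer = new CommonTrainer(1000, learner, 2);
    		trainer.train(net, provider);
    	}
    
    	// Hypothetical placeholder: construct the Net as described in part one.
    	private static Net buildNet() {
    		throw new UnsupportedOperationException("build the Net as shown in part one");
    	}
    }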

    Next up: improving the convergence rate with the Levenberg-Marquardt learning algorithm.

  • Original article: https://www.cnblogs.com/wuseguang/p/4126226.html