  Writing a BP Neural Network in Java (Part 4)

    Continuing from the previous post.

    In parts (1) and (2), the program was organized around Net, Propagation, Trainer, Learner, and DataProvider. This post refactors that design.

    Net

    First, Net. After the activation and cost functions were redefined in the previous post, it looks roughly like this:

    public class Net {
        List<DoubleMatrix> weights = new ArrayList<DoubleMatrix>();
        List<DoubleMatrix> bs = new ArrayList<>();
        List<ActivationFunction> activations = new ArrayList<>();
        CostFunction costFunc;
        CostFunction accuracyFunc;
        int[] nodesNum;  // node count of each layer
        int layersNum;   // number of weight layers

        public CompactDoubleMatrix getCompact() {
            return new CompactDoubleMatrix(this.weights, this.bs);
        }
    }
    

    The getCompact() method packs all the weights and biases into the corresponding "super matrix".
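
    This is what later makes the update code so compact. As a small illustration (a sketch only; net, propResult, and eta stand in for values introduced later in this post, and only the mul/subi operations used later are assumed), a plain gradient-descent step over every layer at once could be written as:

    // Hedged illustration: one gradient step over all weights and biases at once.
    CompactDoubleMatrix weights = net.getCompact();
    CompactDoubleMatrix gradients = propResult.getCompact();
    weights.subi(gradients.mul(eta)); // writes through to net.weights / net.bs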

    DataProvider

    A DataProvider supplies the data:

    public interface DataProvider {
        DoubleMatrix getInput();
        DoubleMatrix getTarget();
    }
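
    A minimal in-memory implementation might look like the sketch below (the class name SimpleDataProvider is mine, not from the original series); each column holds one sample:

    import org.jblas.DoubleMatrix;

    // Hypothetical helper: serves fixed matrices, one sample per column.
    public class SimpleDataProvider implements DataProvider {
        private final DoubleMatrix input;   // inputLen * sampleLength
        private final DoubleMatrix target;  // outputLen * sampleLength

        public SimpleDataProvider(DoubleMatrix input, DoubleMatrix target) {
            this.input = input;
            this.target = target;
        }

        @Override
        public DoubleMatrix getInput() { return input; }

        @Override
        public DoubleMatrix getTarget() { return target; }
    }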
    

    If the inputs are vectors, the provider additionally carries a vector dictionary:

    public interface DictDataProvider extends DataProvider {
        public DoubleMatrix getIndexs();
        public DoubleMatrix getDict();
    }
    

    Each column is one sample; getIndexs() returns the index of each input vector in the dictionary.
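
    To make that concrete, a provider could assemble its input matrix by looking up each index in the dictionary. A sketch (the helper name buildInput is mine):

    // Builds the inputLen * sampleLength input matrix by copying, for each
    // sample, the dictionary column named by its index.
    // indexs is 1 * sampleLength, dict is inputLen * dictSize.
    static DoubleMatrix buildInput(DoubleMatrix indexs, DoubleMatrix dict) {
        DoubleMatrix input = new DoubleMatrix(dict.rows, indexs.columns);
        for (int i = 0; i < indexs.columns; i++) {
            int dictColumn = (int) indexs.get(0, i);
            input.putColumn(i, dict.getColumn(dictColumn));
        }
        return input;
    }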

    I wrote a useful class, BatchDataProviderFactory, that splits the samples into mini-batches:

    public class BatchDataProviderFactory {
        int batchSize;
        int dataLen;
        DataProvider originalProvider;
        List<Integer> endPositions;
        List<DataProvider> providers;

        public BatchDataProviderFactory(int batchSize, DataProvider originalProvider) {
            super();
            this.batchSize = batchSize;
            this.originalProvider = originalProvider;
            this.dataLen = this.originalProvider.getTarget().columns;
            this.initEndPositions();
            this.initProviders();
        }

        public BatchDataProviderFactory(DataProvider originalProvider) {
            this(4, originalProvider);
        }

        public List<DataProvider> getProviders() {
            return providers;
        }
    }
    

    batchSize says how many batches to split the data into, getProviders() returns the generated mini-batches, and originalProvider is the original data being split. The two init methods are not shown above; a plausible sketch of the column splitting follows.
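
    This is only my guess at what initEndPositions() and initProviders() might do, reusing the SimpleDataProvider sketch from above; the original source is not shown:

    private void initEndPositions() {
        // Split the dataLen sample columns into batchSize consecutive ranges.
        endPositions = new ArrayList<>();
        int step = (int) Math.ceil((double) dataLen / batchSize);
        for (int end = step; end < dataLen; end += step)
            endPositions.add(end);
        endPositions.add(dataLen);
    }

    private void initProviders() {
        providers = new ArrayList<>();
        int start = 0;
        for (int end : endPositions) {
            int[] cols = new int[end - start];
            for (int k = 0; k < cols.length; k++)
                cols[k] = start + k;
            // jblas getColumns(int[]) copies the selected sample columns.
            DoubleMatrix input = originalProvider.getInput().getColumns(cols);
            DoubleMatrix target = originalProvider.getTarget().getColumns(cols);
            providers.add(new SimpleDataProvider(input, target));
            start = end;
        }
    }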

    Propagation

    Propagation handles the network's forward and backward passes. The interface is defined as follows:

    public interface Propagation {
        public PropagationResult propagate(Net net, DataProvider provider);
    }
    

    propagate() runs the given data through the given network and returns the result of the pass.

    BasePropagation implements this interface with plain backpropagation:

    public class BasePropagation implements Propagation {

        // Forward pass over multiple samples (one sample per column).
        protected ForwardResult forward(Net net, DoubleMatrix input) {
            ForwardResult result = new ForwardResult();
            result.input = input;
            DoubleMatrix currentResult = input;
            int index = -1;
            for (DoubleMatrix weight : net.weights) {
                index++;
                DoubleMatrix b = net.bs.get(index);
                final ActivationFunction activation = net.activations
                        .get(index);
                currentResult = weight.mmul(currentResult).addColumnVector(b);
                result.netResult.add(currentResult);

                // Derivative of the activation at the pre-activation values.
                DoubleMatrix derivative = activation.derivativeAt(currentResult);
                result.derivativeResult.add(derivative);

                currentResult = activation.valueAt(currentResult);
                result.finalResult.add(currentResult);
            }

            result.netResult = null; // No longer needed.

            return result;
        }

        // Backward pass; gradients are averaged over all samples.
        protected BackwardResult backward(Net net, DoubleMatrix target,
                ForwardResult forwardResult) {
            BackwardResult result = new BackwardResult();

            DoubleMatrix output = forwardResult.getOutput();
            DoubleMatrix outputDerivative = forwardResult.getOutputDerivative();

            result.cost = net.costFunc.valueAt(output, target);
            DoubleMatrix outputDelta = net.costFunc.derivativeAt(output, target)
                    .muli(outputDerivative);
            if (net.accuracyFunc != null) {
                result.accuracy = net.accuracyFunc.valueAt(output, target);
            }

            result.deltas.add(outputDelta);
            for (int i = net.layersNum - 1; i >= 0; i--) {
                DoubleMatrix pdelta = result.deltas.get(result.deltas.size() - 1);

                // Weight gradient, averaged over all samples.
                DoubleMatrix layerInput = i == 0 ? forwardResult.input
                        : forwardResult.finalResult.get(i - 1);
                DoubleMatrix gradient = pdelta.mmul(layerInput.transpose()).div(
                        target.columns);
                result.gradients.add(gradient);
                // Bias gradient.
                result.biasGradients.add(pdelta.rowMeans());

                // Delta for the previous layer. When i == 0 this is the
                // input-layer delta, i.e. the gradient for adjusting the input
                // itself, and is not averaged.
                DoubleMatrix delta = net.weights.get(i).transpose().mmul(pdelta);
                if (i > 0)
                    delta = delta.muli(forwardResult.derivativeResult.get(i - 1));
                result.deltas.add(delta);
            }
            Collections.reverse(result.gradients);
            Collections.reverse(result.biasGradients);

            // Keep only the input-layer delta; the others are not needed.
            DoubleMatrix inputDeltas = result.deltas.get(result.deltas.size() - 1);
            result.deltas.clear();
            result.deltas.add(inputDeltas);

            return result;
        }

        @Override
        public PropagationResult propagate(Net net, DataProvider provider) {
            ForwardResult forwardResult = this.forward(net, provider.getInput());
            BackwardResult backwardResult = this.backward(net,
                    provider.getTarget(), forwardResult);
            PropagationResult result = new PropagationResult(backwardResult);
            result.output = forwardResult.getOutput();
            return result;
        }
    }
    

    Our PropagationResult is, slightly abridged:

    public class PropagationResult {
        DoubleMatrix output;      // network output: outputLen * sampleLength
        DoubleMatrix cost;        // cost matrix: 1 * sampleLength
        DoubleMatrix accuracy;    // accuracy matrix: 1 * sampleLength
        private List<DoubleMatrix> gradients;     // weight gradients
        private List<DoubleMatrix> biasGradients; // bias gradients
        DoubleMatrix inputDeltas; // input-layer deltas: inputLen * sampleLength

        public CompactDoubleMatrix getCompact() {
            return new CompactDoubleMatrix(gradients, biasGradients);
        }
    }
    

    The other implementation of this interface is MiniBatchPropagation. Internally it propagates the mini-batches in parallel and then merges the per-batch results; it is built on BatchDataProviderFactory and BasePropagation.
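
    The source of MiniBatchPropagation is not shown in the post; the fragment below is only my sketch of the idea, given the net and provider from propagate(). It pushes each mini-batch through BasePropagation on a parallel stream and averages the packed gradients; a faithful merge would also weight each batch by its sample count, which is glossed over here:

    // The batch count of 8 is arbitrary for this sketch.
    List<DataProvider> batches =
            new BatchDataProviderFactory(8, provider).getProviders();
    BasePropagation base = new BasePropagation();

    // Propagate every mini-batch in parallel.
    List<PropagationResult> partials = batches.parallelStream()
            .map(batch -> base.propagate(net, batch))
            .collect(java.util.stream.Collectors.toList());

    // Average the packed gradient matrices of the partial results.
    CompactDoubleMatrix sum = partials.get(0).getCompact().dup();
    for (int i = 1; i < partials.size(); i++)
        sum.addi(partials.get(i).getCompact());
    CompactDoubleMatrix meanGradient = sum.mul(1.0 / partials.size());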

    Trainer

    The Trainer interface is defined as:

    public interface Trainer {
        public void train(Net net, DataProvider provider);
    }
    

    A simple implementation:

    public class CommonTrainer implements Trainer {
        int epochs;
        Learner learner;
        Propagation propagation;
        List<Double> costs = new ArrayList<>();
        List<Double> accuracies = new ArrayList<>();

        public void trainOne(Net net, DataProvider provider) {
            PropagationResult propResult = this.propagation
                    .propagate(net, provider);
            learner.learn(net, propResult, provider);

            Double cost = propResult.getMeanCost();
            Double accuracy = propResult.getMeanAccuracy();
            if (cost != null)
                costs.add(cost);
            if (accuracy != null)
                accuracies.add(accuracy);
        }

        @Override
        public void train(Net net, DataProvider provider) {
            for (int i = 0; i < this.epochs; i++) {
                System.out.println("epoch: " + i);
                this.trainOne(net, provider);
            }
        }
    }
    

    It simply iterates for a fixed number of epochs, with no smart stopping criterion, and lets the Learner adjust the weights on every iteration. Wiring it up could look like the sketch below.
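
    A hedged usage sketch: the post shows no constructor or setters for CommonTrainer, so the fields are assigned directly here (same-package access assumed), and the MomentAdaptLearner used is the one defined in the next section:

    void trainExample(Net net, DataProvider provider) {
        CommonTrainer trainer = new CommonTrainer();
        trainer.epochs = 500;                          // fixed iteration count
        trainer.propagation = new BasePropagation();   // or MiniBatchPropagation
        trainer.learner =
                new MomentAdaptLearner<Net, DataProvider>(0.7, 0.01); // moment, eta
        trainer.train(net, provider);
    }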

    Learner

    A Learner adjusts the network weights based on each propagation result. The interface is defined as follows:

    public interface Learner<N extends Net, P extends DataProvider> {
        public void learn(N net, PropagationResult propResult, P provider);
    }
    

    A simple implementation that combines a momentum factor with an adaptive learning rate:

    public class MomentAdaptLearner<N extends Net, P extends DataProvider>
            implements Learner<N, P> {
        double moment = 0.7;  // base momentum factor
        double lmd = 1.05;
        double preCost = 0;
        double eta = 0.01;    // base learning rate
        double currentEta = eta;
        double currentMoment = moment;
        CompactDoubleMatrix preGradient;

        public MomentAdaptLearner(double moment, double eta) {
            super();
            this.moment = moment;
            this.eta = eta;
            this.currentEta = eta;
            this.currentMoment = moment;
        }

        public MomentAdaptLearner() {
        }

        @Override
        public void learn(N net, PropagationResult propResult, P provider) {
            if (this.preGradient == null)
                init(net, propResult, provider);

            double cost = propResult.getMeanCost();
            this.modifyParameter(cost);
            System.out.println("current eta:" + this.currentEta);
            System.out.println("current moment:" + this.currentMoment);
            this.updateGradient(net, propResult, provider);
        }

        public void updateGradient(N net, PropagationResult propResult, P provider) {
            CompactDoubleMatrix netCompact = this.getNetCompact(net, propResult,
                    provider);
            CompactDoubleMatrix gradCompact = this.getGradientCompact(net,
                    propResult, provider);
            // Smoothed step: g = eta * (1 - moment) * gradient + moment * g_prev,
            // subtracted in place from the packed weights.
            gradCompact = gradCompact.mul(currentEta * (1 - currentMoment)).addi(
                    preGradient.mul(currentMoment));
            netCompact.subi(gradCompact);
            this.preGradient = gradCompact;
        }

        public CompactDoubleMatrix getNetCompact(N net,
                PropagationResult propResult, P provider) {
            return net.getCompact();
        }

        public CompactDoubleMatrix getGradientCompact(N net,
                PropagationResult propResult, P provider) {
            return propResult.getCompact();
        }

        public void modifyParameter(double cost) {
            // Clamp the learning rate to [0.0001, 10]; otherwise adapt it by
            // how the cost moved since the last iteration.
            if (this.currentEta > 10) {
                this.currentEta = 10;
            } else if (this.currentEta < 0.0001) {
                this.currentEta = 0.0001;
            } else if (cost < this.preCost) {
                // Cost decreased: speed up, restore the momentum.
                this.currentEta *= 1.05;
                this.currentMoment = moment;
            } else if (cost < 1.04 * this.preCost) {
                // Cost grew slightly: slow down, damp the momentum.
                this.currentEta *= 0.7;
                this.currentMoment *= 0.7;
            } else {
                // Cost blew up: reset the learning rate, nearly kill the momentum.
                this.currentEta = eta;
                this.currentMoment = 0.1;
            }
            this.preCost = cost;
        }

        public void init(Net net, PropagationResult propResult, P provider) {
            // Seed preGradient with a gradient-shaped CompactDoubleMatrix.
            PropagationResult pResult = new PropagationResult(net);
            preGradient = pResult.getCompact().dup();
        }
    }
    

    In the code above you can see how CompactDoubleMatrix packages the weight variables. It behaves like a single super matrix / super vector, completely hiding the layered structure inside, which keeps the code concise.
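
    CompactDoubleMatrix itself was introduced in the previous post; as a reminder of the contract the code above relies on, here is a minimal reconstruction (mine, not the original): it wraps a list of DoubleMatrix instances and applies every operation pairwise across all of them.

    import java.util.ArrayList;
    import java.util.List;
    import org.jblas.DoubleMatrix;

    // Minimal reconstruction of the CompactDoubleMatrix contract used here.
    public class CompactDoubleMatrix {
        final List<DoubleMatrix> matrices = new ArrayList<>();

        public CompactDoubleMatrix(List<DoubleMatrix> weights, List<DoubleMatrix> bs) {
            matrices.addAll(weights);
            matrices.addAll(bs);
        }

        private CompactDoubleMatrix() {
        }

        public void append(DoubleMatrix m) {
            matrices.add(m);
        }

        // Scalar multiply into a new compact matrix.
        public CompactDoubleMatrix mul(double s) {
            CompactDoubleMatrix r = new CompactDoubleMatrix();
            for (DoubleMatrix m : matrices)
                r.matrices.add(m.mul(s));
            return r;
        }

        // Element-wise add in place; shapes must match pairwise.
        public CompactDoubleMatrix addi(CompactDoubleMatrix o) {
            for (int i = 0; i < matrices.size(); i++)
                matrices.get(i).addi(o.matrices.get(i));
            return this;
        }

        // Element-wise subtract in place, writing through to the wrapped matrices.
        public CompactDoubleMatrix subi(CompactDoubleMatrix o) {
            for (int i = 0; i < matrices.size(); i++)
                matrices.get(i).subi(o.matrices.get(i));
            return this;
        }

        // Deep copy.
        public CompactDoubleMatrix dup() {
            CompactDoubleMatrix r = new CompactDoubleMatrix();
            for (DoubleMatrix m : matrices)
                r.matrices.add(m.dup());
            return r;
        }
    }

    Because subi() mutates the wrapped matrices in place, netCompact.subi(gradCompact) in MomentAdaptLearner writes straight through to net.weights and net.bs.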

    Meanwhile, a subclass adds synchronized updating of the dictionary, and the code stays simple: the extra matrices that need adjusting are just appended to the super matrix, and the parent class adjusts them all uniformly:

    public class DictMomentLearner extends
            MomentAdaptLearner<Net, DictDataProvider> {

        public DictMomentLearner(double moment, double eta) {
            super(moment, eta);
        }

        public DictMomentLearner() {
            super();
        }

        @Override
        public CompactDoubleMatrix getNetCompact(Net net,
                PropagationResult propResult, DictDataProvider provider) {
            CompactDoubleMatrix result = super.getNetCompact(net, propResult,
                    provider);
            result.append(provider.getDict());
            return result;
        }

        @Override
        public CompactDoubleMatrix getGradientCompact(Net net,
                PropagationResult propResult, DictDataProvider provider) {
            CompactDoubleMatrix result = super.getGradientCompact(net, propResult,
                    provider);
            result.append(DictUtil.getDictGradient(provider, propResult));
            return result;
        }

        @Override
        public void init(Net net, PropagationResult propResult,
                DictDataProvider provider) {
            DoubleMatrix preDictGradient = DoubleMatrix.zeros(
                    provider.getDict().rows, provider.getDict().columns);
            super.init(net, propResult, provider);
            this.preGradient.append(preDictGradient);
        }
    }
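
    DictUtil.getDictGradient() is not shown in the post; conceptually it has to scatter the input-layer deltas back onto the dictionary columns each sample was looked up from. A reconstruction sketch (it assumes access to PropagationResult's inputDeltas field, and the per-sample averaging is my assumption):

    static DoubleMatrix getDictGradient(DictDataProvider provider,
            PropagationResult propResult) {
        DoubleMatrix dict = provider.getDict();
        DoubleMatrix indexs = provider.getIndexs();   // 1 * sampleLength
        DoubleMatrix deltas = propResult.inputDeltas; // inputLen * sampleLength
        DoubleMatrix gradient = DoubleMatrix.zeros(dict.rows, dict.columns);
        for (int s = 0; s < indexs.columns; s++) {
            int col = (int) indexs.get(0, s);
            // Accumulate this sample's input delta onto its dictionary column.
            gradient.putColumn(col,
                    gradient.getColumn(col).addi(deltas.getColumn(s)));
        }
        // Average over the number of samples (an assumption on my part).
        return gradient.divi(indexs.columns);
    }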
    