Writing a BP Neural Network in Java (Part 4)

    Continuing from the previous post.

    In parts (1) and (2), the program was organized around Net, Propagation, Trainer, Learner and DataProvider. This post refactors that design.

    Net

    First up is Net. After the activation and cost functions were redefined in the previous post, it looks roughly like this:

    List<DoubleMatrix> weights = new ArrayList<DoubleMatrix>();
    List<DoubleMatrix> bs = new ArrayList<>();
    List<ActivationFunction> activations = new ArrayList<>();
    CostFunction costFunc;
    CostFunction accuracyFunc;
    int[] nodesNum;
    int layersNum;

    public CompactDoubleMatrix getCompact() {
        return new CompactDoubleMatrix(this.weights, this.bs);
    }
    

    The getCompact() method packs the weights and biases into the corresponding super-matrix.

    DataProvider

    DataProvider is the supplier of the data.

    public interface DataProvider {
        DoubleMatrix getInput();
        DoubleMatrix getTarget();
    }
    

    If the input consists of vectors, the provider also carries a vector dictionary.

    public interface DictDataProvider extends DataProvider {
        public DoubleMatrix getIndexs();
        public DoubleMatrix getDict();
    }
    

    Each column is one sample. getIndexs() returns the indices of the input vectors within the dictionary.
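
    For concreteness, a minimal in-memory implementation might look like this (purely illustrative; the post itself does not show one):

    import org.jblas.DoubleMatrix;

    public class SimpleDictDataProvider implements DictDataProvider {
        private final DoubleMatrix input;  // inputLen * sampleLength
        private final DoubleMatrix target; // targetLen * sampleLength
        private final DoubleMatrix indexs; // one dictionary index per input-vector slot, per sample
        private final DoubleMatrix dict;   // wordLen * dictSize

        public SimpleDictDataProvider(DoubleMatrix input, DoubleMatrix target,
                DoubleMatrix indexs, DoubleMatrix dict) {
            this.input = input;
            this.target = target;
            this.indexs = indexs;
            this.dict = dict;
        }

        @Override public DoubleMatrix getInput() { return input; }
        @Override public DoubleMatrix getTarget() { return target; }
        @Override public DoubleMatrix getIndexs() { return indexs; }
        @Override public DoubleMatrix getDict() { return dict; }
    }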

    I wrote a handy class, BatchDataProviderFactory, that splits the samples into mini-batches.

    int batchSize;
    int dataLen;
    DataProvider originalProvider;
    List<Integer> endPositions;
    List<DataProvider> providers;

    public BatchDataProviderFactory(int batchSize, DataProvider originalProvider) {
        super();
        this.batchSize = batchSize;
        this.originalProvider = originalProvider;
        this.dataLen = this.originalProvider.getTarget().columns;
        this.initEndPositions();
        this.initProviders();
    }

    public BatchDataProviderFactory(DataProvider originalProvider) {
        this(4, originalProvider);
    }

    public List<DataProvider> getProviders() {
        return providers;
    }
    

    Despite the name, batchSize specifies how many batches to split into (not the size of each batch); getProviders() returns the generated mini-batches; originalProvider is the original data being split.
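
    For example, given some DataProvider named provider (illustrative snippet):

    // Split the data into 4 mini-batches and inspect each one's sample count.
    BatchDataProviderFactory factory = new BatchDataProviderFactory(4, provider);
    for (DataProvider batch : factory.getProviders()) {
        System.out.println("batch samples: " + batch.getTarget().columns);
    }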

    Propagation

    Propagation handles the network's forward and backward propagation passes. The interface is defined as follows:

    public interface Propagation {
        public PropagationResult propagate(Net net, DataProvider provider);
    }
    

    The propagate method propagates the given data through the given network and returns the result.

    BasePropagation implements this interface and provides plain backpropagation:

    public class BasePropagation implements Propagation {

        // Forward pass over multiple samples (one sample per column).
        protected ForwardResult forward(Net net, DoubleMatrix input) {

            ForwardResult result = new ForwardResult();
            result.input = input;
            DoubleMatrix currentResult = input;
            int index = -1;
            for (DoubleMatrix weight : net.weights) {
                index++;
                DoubleMatrix b = net.bs.get(index);
                final ActivationFunction activation = net.activations
                        .get(index);
                currentResult = weight.mmul(currentResult).addColumnVector(b);
                result.netResult.add(currentResult);

                // Derivative of the activation, evaluated at the pre-activation values.
                DoubleMatrix derivative = activation.derivativeAt(currentResult);
                result.derivativeResult.add(derivative);

                currentResult = activation.valueAt(currentResult);
                result.finalResult.add(currentResult);
            }

            result.netResult = null; // No longer needed.

            return result;
        }

        // Backward pass; gradients are averaged over all samples.
        protected BackwardResult backward(Net net, DoubleMatrix target,
                ForwardResult forwardResult) {
            BackwardResult result = new BackwardResult();

            DoubleMatrix output = forwardResult.getOutput();
            DoubleMatrix outputDerivative = forwardResult.getOutputDerivative();

            result.cost = net.costFunc.valueAt(output, target);
            DoubleMatrix outputDelta = net.costFunc.derivativeAt(output, target)
                    .muli(outputDerivative);
            if (net.accuracyFunc != null) {
                result.accuracy = net.accuracyFunc.valueAt(output, target);
            }

            result.deltas.add(outputDelta);
            for (int i = net.layersNum - 1; i >= 0; i--) {
                DoubleMatrix pdelta = result.deltas.get(result.deltas.size() - 1);

                // Weight gradient, averaged over all samples.
                DoubleMatrix layerInput = i == 0 ? forwardResult.input
                        : forwardResult.finalResult.get(i - 1);
                DoubleMatrix gradient = pdelta.mmul(layerInput.transpose()).div(
                        target.columns);
                result.gradients.add(gradient);
                // Bias gradient.
                result.biasGradients.add(pdelta.rowMeans());

                // Delta of the previous layer. When i == 0 this is the input-layer
                // error, i.e. the gradient for adjusting the input itself, and is
                // not averaged.
                DoubleMatrix delta = net.weights.get(i).transpose().mmul(pdelta);
                if (i > 0)
                    delta = delta.muli(forwardResult.derivativeResult.get(i - 1));
                result.deltas.add(delta);
            }
            Collections.reverse(result.gradients);
            Collections.reverse(result.biasGradients);

            // None of the other deltas are needed.
            DoubleMatrix inputDeltas = result.deltas.get(result.deltas.size() - 1);
            result.deltas.clear();
            result.deltas.add(inputDeltas);

            return result;
        }

        @Override
        public PropagationResult propagate(Net net, DataProvider provider) {
            ForwardResult forwardResult = this.forward(net, provider.getInput());
            BackwardResult backwardResult = this.backward(net, provider.getTarget(),
                    forwardResult);
            PropagationResult result = new PropagationResult(backwardResult);
            result.output = forwardResult.getOutput();
            return result;
        }
    }
    

    The PropagationResult we define is roughly:

    public class PropagationResult {
        DoubleMatrix output;   // output matrix: outputLen * sampleLength
        DoubleMatrix cost;     // cost matrix: 1 * sampleLength
        DoubleMatrix accuracy; // accuracy matrix: 1 * sampleLength
        private List<DoubleMatrix> gradients;     // weight gradient matrices
        private List<DoubleMatrix> biasGradients; // bias gradient matrices
        DoubleMatrix inputDeltas; // input-layer delta matrix: inputLen * sampleLength

        public CompactDoubleMatrix getCompact() {
            return new CompactDoubleMatrix(gradients, biasGradients);
        }
    }
    

    Another class implementing this interface is MiniBatchPropagation. Internally it propagates the samples in parallel and then combines the results of the individual mini-batches, building on BatchDataProviderFactory and BasePropagation.
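
    Its code is not shown here, but a rough sketch of the described behavior might be (the merge step is hypothetical, since the combining API of PropagationResult is not shown):

    import java.util.List;
    import java.util.stream.Collectors;

    public class MiniBatchPropagation implements Propagation {
        private final int batchSize = 4;
        private final Propagation base = new BasePropagation();

        @Override
        public PropagationResult propagate(Net net, DataProvider provider) {
            List<DataProvider> batches =
                    new BatchDataProviderFactory(batchSize, provider).getProviders();
            // Propagate every mini-batch in parallel.
            List<PropagationResult> results = batches.parallelStream()
                    .map(batch -> base.propagate(net, batch))
                    .collect(Collectors.toList());
            return merge(results);
        }

        // Hypothetical: combine the per-batch results, e.g. concatenate the cost
        // and accuracy columns and average the gradients. Omitted because the
        // real PropagationResult API for this is not shown in the post.
        private PropagationResult merge(List<PropagationResult> results) {
            throw new UnsupportedOperationException("sketch only");
        }
    }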

    Trainer

    The Trainer interface is defined as:

    public interface Trainer {
        public void train(Net net, DataProvider provider);
    }
    

    A simple implementing class:

    public class CommonTrainer implements Trainer {
        int ecophs;
        Learner learner;
        Propagation propagation;
        List<Double> costs = new ArrayList<>();
        List<Double> accuracys = new ArrayList<>();

        public void trainOne(Net net, DataProvider provider) {
            PropagationResult propResult = this.propagation
                    .propagate(net, provider);
            learner.learn(net, propResult, provider);

            Double cost = propResult.getMeanCost();
            Double accuracy = propResult.getMeanAccuracy();
            if (cost != null)
                costs.add(cost);
            if (accuracy != null)
                accuracys.add(accuracy);
        }

        @Override
        public void train(Net net, DataProvider provider) {
            for (int i = 0; i < this.ecophs; i++) {
                System.out.println("echops:" + i);
                this.trainOne(net, provider);
            }
        }
    }
    

    It simply iterates ecophs times, with no smart stopping criterion; each iteration uses the Learner to adjust the weights.
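
    Wiring everything together might look like this (illustrative; the post shows no constructor for CommonTrainer, so this assumes package-level access to its fields):

    CommonTrainer trainer = new CommonTrainer();
    trainer.ecophs = 100;                                   // number of iterations
    trainer.learner = new MomentAdaptLearner<>(0.7, 0.01);  // introduced below
    trainer.propagation = new MiniBatchPropagation();
    trainer.train(net, provider);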

    Learner

    The Learner adjusts the network weights based on each propagation result. The interface is defined as follows:

    public interface Learner<N extends Net, P extends DataProvider> {
        public void learn(N net, PropagationResult propResult, P provider);
    }
    

    A simple implementation that adjusts the weights using a momentum factor and an adaptive learning rate:

    public class MomentAdaptLearner<N extends Net, P extends DataProvider>
            implements Learner<N, P> {
        double moment = 0.7;
        double lmd = 1.05;
        double preCost = 0;
        double eta = 0.01;
        double currentEta = eta;
        double currentMoment = moment;
        CompactDoubleMatrix preGradient;

        public MomentAdaptLearner(double moment, double eta) {
            super();
            this.moment = moment;
            this.eta = eta;
            this.currentEta = eta;
            this.currentMoment = moment;
        }

        public MomentAdaptLearner() {
        }

        @Override
        public void learn(N net, PropagationResult propResult, P provider) {
            if (this.preGradient == null)
                init(net, propResult, provider);

            double cost = propResult.getMeanCost();
            this.modifyParameter(cost);
            System.out.println("current eta:" + this.currentEta);
            System.out.println("current moment:" + this.currentMoment);
            this.updateGradient(net, propResult, provider);
        }

        public void updateGradient(N net, PropagationResult propResult, P provider) {
            CompactDoubleMatrix netCompact = this.getNetCompact(net, propResult,
                    provider);
            CompactDoubleMatrix gradCompact = this.getGradientCompact(net,
                    propResult, provider);
            // New step = eta * (1 - moment) * gradient + moment * previous step.
            gradCompact = gradCompact.mul(currentEta * (1 - currentMoment)).addi(
                    preGradient.mul(currentMoment));
            netCompact.subi(gradCompact);
            this.preGradient = gradCompact;
        }

        public CompactDoubleMatrix getNetCompact(N net,
                PropagationResult propResult, P provider) {
            return net.getCompact();
        }

        public CompactDoubleMatrix getGradientCompact(N net,
                PropagationResult propResult, P provider) {
            return propResult.getCompact();
        }

        public void modifyParameter(double cost) {
            // Clamp the learning rate first; otherwise adapt it from the cost trend.
            if (this.currentEta > 10) {
                this.currentEta = 10;
            } else if (this.currentEta < 0.0001) {
                this.currentEta = 0.0001;
            } else if (cost < this.preCost) {
                // Cost fell: speed up and restore the momentum.
                this.currentEta *= lmd;
                this.currentMoment = moment;
            } else if (cost < 1.04 * this.preCost) {
                // Cost rose slightly: slow down.
                this.currentEta *= 0.7;
                this.currentMoment *= 0.7;
            } else {
                // Cost blew up: reset.
                this.currentEta = eta;
                this.currentMoment = 0.1;
            }
            this.preCost = cost;
        }

        public void init(Net net, PropagationResult propResult, P provider) {
            PropagationResult pResult = new PropagationResult(net);
            preGradient = pResult.getCompact().dup();
        }
    }
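
    Written out, updateGradient implements the update

        delta_t = eta_t * (1 - m_t) * g_t + m_t * delta_{t-1},    W_t = W_{t-1} - delta_t

    where g_t is the current gradient, eta_t the current learning rate, m_t the current momentum, and delta_{t-1} the previous step (preGradient); modifyParameter then adjusts eta_t and m_t according to how the mean cost moved.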
    

    In the code above you can see how CompactDoubleMatrix encapsulates the weight variables and keeps the code concise: it presents itself as a single super-matrix (or super-vector), completely hiding the internal structure.
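
    CompactDoubleMatrix itself was built earlier in this series. As a reminder of what the learner relies on, here is a plausible minimal reconstruction (not the original code); the key point is that subi mutates the wrapped matrices in place, so subtracting from the net's compact view updates the net's actual weights:

    import java.util.ArrayList;
    import java.util.List;
    import org.jblas.DoubleMatrix;

    public class CompactDoubleMatrix {
        private final List<DoubleMatrix> mats = new ArrayList<>();

        public CompactDoubleMatrix() {
        }

        public CompactDoubleMatrix(List<DoubleMatrix> weights, List<DoubleMatrix> bs) {
            mats.addAll(weights);
            mats.addAll(bs);
        }

        public void append(DoubleMatrix m) {
            mats.add(m);
        }

        // Element-wise scaling into a new compact matrix.
        public CompactDoubleMatrix mul(double s) {
            CompactDoubleMatrix result = new CompactDoubleMatrix();
            for (DoubleMatrix m : mats)
                result.append(m.mul(s));
            return result;
        }

        // In-place element-wise addition; both sides must share the same structure.
        public CompactDoubleMatrix addi(CompactDoubleMatrix other) {
            for (int i = 0; i < mats.size(); i++)
                mats.get(i).addi(other.mats.get(i));
            return this;
        }

        // In-place subtraction, mutating the underlying matrices directly.
        public void subi(CompactDoubleMatrix other) {
            for (int i = 0; i < mats.size(); i++)
                mats.get(i).subi(other.mats.get(i));
        }

        // Deep copy.
        public CompactDoubleMatrix dup() {
            CompactDoubleMatrix result = new CompactDoubleMatrix();
            for (DoubleMatrix m : mats)
                result.append(m.dup());
            return result;
        }
    }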

    A subclass adds synchronized updating of the dictionary, and the code stays just as concise: it simply appends the matrices that need adjusting to the super-matrix, and the parent class then adjusts them all uniformly:

    public class DictMomentLearner extends
            MomentAdaptLearner<Net, DictDataProvider> {

        public DictMomentLearner(double moment, double eta) {
            super(moment, eta);
        }

        public DictMomentLearner() {
            super();
        }

        @Override
        public CompactDoubleMatrix getNetCompact(Net net,
                PropagationResult propResult, DictDataProvider provider) {
            CompactDoubleMatrix result = super.getNetCompact(net, propResult,
                    provider);
            result.append(provider.getDict());
            return result;
        }

        @Override
        public CompactDoubleMatrix getGradientCompact(Net net,
                PropagationResult propResult, DictDataProvider provider) {
            CompactDoubleMatrix result = super.getGradientCompact(net, propResult,
                    provider);
            result.append(DictUtil.getDictGradient(provider, propResult));
            return result;
        }

        @Override
        public void init(Net net, PropagationResult propResult,
                DictDataProvider provider) {
            DoubleMatrix preDictGradient = DoubleMatrix.zeros(
                    provider.getDict().rows, provider.getDict().columns);
            super.init(net, propResult, provider);
            this.preGradient.append(preDictGradient);
        }
    }
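
    The helper DictUtil.getDictGradient is referenced above but never shown. A purely hypothetical sketch of what it has to do, assuming getIndexs() holds one dictionary index per input-vector slot per sample column, and assuming package-level access to PropagationResult.inputDeltas:

    import org.jblas.DoubleMatrix;

    // Hypothetical sketch only: scatter the input-layer deltas back onto the
    // dictionary columns they were read from. The real implementation may also
    // normalize, e.g. divide by the sample count.
    public class DictUtil {
        public static DoubleMatrix getDictGradient(DictDataProvider provider,
                PropagationResult propResult) {
            DoubleMatrix dict = provider.getDict();            // wordLen * dictSize
            DoubleMatrix indexs = provider.getIndexs();        // wordsPerSample * sampleLength
            DoubleMatrix inputDeltas = propResult.inputDeltas; // inputLen * sampleLength

            DoubleMatrix gradient = DoubleMatrix.zeros(dict.rows, dict.columns);
            int wordLen = dict.rows;
            int wordsPerSample = inputDeltas.rows / wordLen;

            for (int s = 0; s < inputDeltas.columns; s++) {
                for (int w = 0; w < wordsPerSample; w++) {
                    int dictIndex = (int) indexs.get(w, s);
                    // Accumulate this word's slice of the input delta into its column.
                    for (int r = 0; r < wordLen; r++) {
                        double v = gradient.get(r, dictIndex)
                                + inputDeltas.get(w * wordLen + r, s);
                        gradient.put(r, dictIndex, v);
                    }
                }
            }
            return gradient;
        }
    }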
    
    Original post: https://www.cnblogs.com/wuseguang/p/4140408.html