jrae Source Code Analysis (Part 1)

    jrae is a Java implementation of the algorithm proposed in the paper "Semi-Supervised Recursive Autoencoders for Predicting Sentiment Distributions": a semi-supervised recursive autoencoder used to predict sentiment distributions. See the paper for details; the code lives at https://github.com/sancha/jrae.

    Bird's-eye view

    Training flow in the main function

    FineTunableTheta tunedTheta = rae.train(params); // train the neural network weights from the parameters and data
    tunedTheta.Dump(params.ModelFile);

    System.out.println("RAE trained. The model file is saved in "
        + params.ModelFile);

    // Feature extractor
    RAEFeatureExtractor fe = new RAEFeatureExtractor(params.EmbeddingSize,
        tunedTheta, params.AlphaCat, params.Beta, params.CatSize,
        params.Dataset.Vocab.size(), rae.f);

    // Extract the classifier's training data (RAE features for each example)
    List<LabeledDatum<Double, Integer>> classifierTrainingData = fe
        .extractFeaturesIntoArray(params.Dataset, params.Dataset.Data,
            params.TreeDumpDir);

    // Train the softmax classifier and report the training accuracy
    SoftmaxClassifier<Double, Integer> classifier = new SoftmaxClassifier<Double, Integer>();
    Accuracy TrainAccuracy = classifier.train(classifierTrainingData);
    System.out.println("Train Accuracy :" + TrainAccuracy.toString());
    

    Several important interfaces and their implementing classes

    1. Minimizer<T extends DifferentiableFunction>

    public interface Minimizer<T extends DifferentiableFunction> {
    
      /**
       * Attempts to find an unconstrained minimum of the objective
       * <code>function</code> starting at <code>initial</code>, within
       * <code>functionTolerance</code>.
       *
       * @param function          the objective function
       * @param functionTolerance a <code>double</code> value
       * @param initial           an initial feasible point
       * @return Unconstrained minimum of function
       */
      double[] minimize(T function, double functionTolerance, double[] initial);
      double[] minimize(T function, double functionTolerance, double[] initial, int maxIterations);
    
    }
    

    As the comment says, this interface finds an unconstrained minimum of a given objective function. The objective must be differentiable everywhere and implement the DifferentiableFunction interface. functionTolerance is the convergence tolerance, initial is the starting point, and maxIterations is the maximum number of iterations. A toy implementation of the function interfaces is sketched after the definitions below.

    public interface DifferentiableFunction extends Function {
      double[] derivativeAt(double[] x);
    }
    
    public interface Function {
      int dimension();
      double valueAt(double[] x);
    }
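
    For illustration only (this class is not part of jrae), a toy DifferentiableFunction could look like the sketch below; any such object can then be handed to a Minimizer implementation:

    // Toy objective f(x) = (x0 - 1)^2 + (x1 + 2)^2 together with its exact gradient.
    // Purely illustrative; jrae's real objectives are RAECost and SoftmaxCost.
    public class ToyQuadratic implements DifferentiableFunction {
      @Override
      public int dimension() {
        return 2;
      }

      @Override
      public double valueAt(double[] x) {
        return (x[0] - 1) * (x[0] - 1) + (x[1] + 2) * (x[1] + 2);
      }

      @Override
      public double[] derivativeAt(double[] x) {
        return new double[] { 2 * (x[0] - 1), 2 * (x[1] + 2) };
      }
    }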
    

    The QNMinimizer class implements this interface and optimizes the objective function with the L-BFGS algorithm. Below is the class comment describing the algorithm:

    /**
     * This code is part of the Stanford NLP Toolkit.
     * 
     * 
     * An implementation of L-BFGS for Quasi Newton unconstrained minimization.
     * 
     * The general outline of the algorithm is taken from: <blockquote> <i>Numerical
     * Optimization</i> (second edition) 2006 Jorge Nocedal and Stephen J. Wright
     * </blockquote> A variety of different options are available.
     * 
     * <h3>LINESEARCHES</h3>
     * 
     * BACKTRACKING: This routine simply starts with a guess for step size of 1. If
     * the step size doesn't supply a sufficient decrease in the function value the
     * step is updated through step = 0.1*step. This method is certainly simpler,
     * but doesn't allow for an increase in step size, and isn't well suited for
     * Quasi Newton methods.
     * 
     * MINPACK: This routine is based off of the implementation used in MINPACK.
     * This routine finds a point satisfying the Wolfe conditions, which state that
     * a point must have a sufficiently smaller function value, and a gradient of
     * smaller magnitude. This provides enough to prove theoretically quadratic
     * convergence. In order to find such a point the linesearch first finds an
     * interval which must contain a satisfying point, and then progressively
     * reduces that interval all using cubic or quadratic interpolation.
     * 
     * 
     * SCALING: L-BFGS allows the initial guess at the hessian to be updated at each
     * step. Standard BFGS does this by approximating the hessian as a scaled
     * identity matrix. To use this method set the scaleOpt to SCALAR. A better way
     * of approximating the hessian is by using a scaling diagonal matrix. The
     * diagonal can then be updated as more information comes in. This method can be
     * used by setting scaleOpt to DIAGONAL.
     * 
     * 
     * CONVERGENCE: Previously convergence was gauged by looking at the average
     * decrease per step, dividing that by the current value, and terminating when
     * that value becomes smaller than TOL. This method fails when the function
     * value approaches zero, so two other convergence criteria are used. The first
     * stores the initial gradient norm |g0|, then terminates when the new gradient
     * norm |g| is sufficiently smaller: i.e., |g| < eps*|g0|; the second checks
     * if |g| < eps*max( 1 , |x| ), which is essentially checking to see if the
     * gradient is numerically zero.
     * 
     * Each of these convergence criteria can be turned on or off by setting the
     * flags: <blockquote><code>
     * private boolean useAveImprovement = true;
     * private boolean useRelativeNorm = true;
     * private boolean useNumericalZero = true;
     * </code></blockquote>
     * 
     * To use the QNMinimizer first construct it using <blockquote><code>
     * QNMinimizer qn = new QNMinimizer(mem, true)
     * </code>
     * </blockquote> mem - the number of previous estimate vector pairs to store,
     * generally 15 is plenty. true - this tells the QN to use the MINPACK
     * linesearch with DIAGONAL scaling. false would lead to the use of the criteria
     * used in the old QNMinimizer class.
     */
    

    OK. Combined with my earlier posts you can get a picture of how L-BFGS works; this class implements that algorithm, with a few modifications in the details. I will skip the concrete implementation for now and come back to it later.
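
    Putting the pieces together, a minimal usage sketch (the cost object, tolerance, and iteration cap are illustrative assumptions, not values taken from the jrae source) would be:

    // cost is any DifferentiableFunction, e.g. jrae's RAECost; theta0 is the initial parameter vector.
    QNMinimizer qn = new QNMinimizer(15, true); // keep 15 estimate vector pairs, MINPACK linesearch with DIAGONAL scaling
    double[] theta0 = new double[cost.dimension()]; // all-zero start, purely illustrative
    double[] thetaStar = qn.minimize(cost, 1e-5, theta0, 200); // assumed tolerance and iteration limit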

    2. DifferentiableFunction

    The definition of DifferentiableFunction was given above; it represents a differentiable function. The abstract class MemoizedDifferentiableFunction implements this interface and factors out some common code:

    public abstract class MemoizedDifferentiableFunction implements DifferentiableFunction {
    	protected double[] prevQuery, gradient;
    	protected double value;
    	protected int evalCount;
    	
    	protected void initPrevQuery()
    	{
    		prevQuery = new double[ dimension() ];
    	}
    	
    	protected boolean requiresEvaluation(double[] x)
    	{
    		if(DoubleArrays.equals(x,prevQuery))
    			return false;
    		
    		System.arraycopy(x, 0, prevQuery, 0, x.length);
    		evalCount++;	
    		return true;
    	}
    	
    	@Override
    	public double[] derivativeAt(double[] x){
    		if(DoubleArrays.equals(x,prevQuery))
    			return gradient;
    		valueAt(x);
    		return gradient;
    	}
    }
    

    The common logic it encapsulates: it caches the argument of the previous query, so if the same argument is passed in again the cached result is returned directly; it counts how many evaluations have been performed; and it fixes the derivative workflow: derivativeAt first calls valueAt to compute the current value $f(x)$ and then returns the gradient. valueAt is implemented by subclasses, under the convention that a subclass computes $f'(x)$ as a by-product of computing $f(x)$ and stores it in the gradient field. A toy subclass following this convention is sketched below.
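
    To make the convention concrete, here is a toy subclass (not part of jrae, purely illustrative) that computes $f(x)=\sum_i x_i^2$ and fills in gradient as a by-product of valueAt:

    // Illustrative only: a toy cost following the MemoizedDifferentiableFunction convention.
    // valueAt computes f(x) and stores f'(x) in `gradient`, so the inherited derivativeAt can reuse it.
    public class SquaredNormCost extends MemoizedDifferentiableFunction {
      private final int dim;

      public SquaredNormCost(int dim) {
        this.dim = dim;
        initPrevQuery();            // allocate the cache for the previous query
        gradient = new double[dim];
      }

      @Override
      public int dimension() {
        return dim;
      }

      @Override
      public double valueAt(double[] x) {
        if (!requiresEvaluation(x)) // same x as last time: return the cached value
          return value;
        value = 0;
        for (int i = 0; i < dim; i++) {
          value += x[i] * x[i];     // f(x) = sum of x_i^2
          gradient[i] = 2 * x[i];   // f'(x) computed alongside, as the base class expects
        }
        return value;
      }
    }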

    Its two subclasses are RAECost and SoftmaxCost.

    SoftmaxCost computes, for the given samples, the error of the given weights, and its derivative gives the gradient along which the error decreases. It corresponds to a two-layer network: the input layer is the features, the output layer is the label, and the transfer function is softmax. A schematic sketch of such a cost follows.
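
    As a rough illustration of what such a softmax layer computes (a schematic sketch with assumed variable names, not the actual SoftmaxCost code):

    // Schematic softmax forward pass and cross-entropy gradient for one example.
    // Assumed: W is a catSize x featureSize weight matrix, x the feature vector,
    // label the gold class, gradW the gradient accumulator.
    double[] scores = new double[catSize];
    for (int k = 0; k < catSize; k++)
      for (int j = 0; j < featureSize; j++)
        scores[k] += W[k][j] * x[j];

    double max = Double.NEGATIVE_INFINITY, sum = 0;
    for (double s : scores) max = Math.max(max, s);  // subtract the max for numerical stability
    double[] prob = new double[catSize];
    for (int k = 0; k < catSize; k++) { prob[k] = Math.exp(scores[k] - max); sum += prob[k]; }
    for (int k = 0; k < catSize; k++) prob[k] /= sum;

    double cost = -Math.log(prob[label]);            // cross-entropy error for this example
    // Gradient w.r.t. W: (prob[k] - 1{k == label}) * x[j]
    for (int k = 0; k < catSize; k++)
      for (int j = 0; j < featureSize; j++)
        gradW[k][j] += (prob[k] - (k == label ? 1 : 0)) * x[j];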

    RAECost computes, for the given samples, the error of the given weights; here the error is the sum of the error of building the recursive tree and the error of the label classification, and the derivative (gradient) is likewise the sum of the two gradients, as sketched schematically below.
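
    Conceptually (again a schematic sketch with assumed helper names, not the actual RAECost implementation), its valueAt combines the two terms roughly like this:

    // Schematic only: total cost = tree reconstruction error + label classification error,
    // and the gradient is the element-wise sum of the two gradients.
    double treeCost  = reconstructionCost(theta, data);  // assumed helper for the RAE tree error
    double labelCost = classificationCost(theta, data);  // assumed helper for the softmax label error
    value = treeCost + labelCost;
    for (int i = 0; i < gradient.length; i++)
      gradient[i] = treeGradient[i] + labelGradient[i];  // treeGradient/labelGradient are assumed arrays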

    When the Minimizer interface is invoked to run the optimization, the first argument passed in is the RAECost object; when the optimization finishes, training is finished.

    References:

    http://www.socher.org/index.php/Main/Semi-SupervisedRecursiveAutoencodersForPredictingSentimentDistributions
