zoukankan      html  css  js  c++  java
  • artificial neural network in spark MLLib

    神经网络模型

    每个node包含两种操作:线性变换(仿射变换)和激发函数(activation function)。

    其中仿射变换是通用的,而激发函数可以很多种,如下图。

    MLLib中实现ANN

    使用两层(Layer)来对应模型中的一层:

    • AffineLayer 仿射变换: output = W · input + b
    • 如果是最后一层,使用SoftmaxLayerWithCrossEntropyLoss或者SigmoidLayerWithSquaredError;如果是中间层,则使用functionalLayer(new SigmoidFunction()). 目前MLlib只支持sigmoid函数,实际上ReLU激发函数更普遍

    BP算法计算Gradient的四个步骤:

    对照BP算法的步骤,可以发现分隔成Affine和Activation的好处。BP1和BP2中的计算,不同的activation函数有不同的计算形式,将affine变换和activation函数解耦方便组合,进而方便形成各种类型的神经网络。

    MLLib FeedForward Trainer

    训练器重要模块如下:

    ANN模型中每层对应AffineLayer + FunctionalLayerModel OR SofrmaxLayerModelWIthCrossEntropyLoss
    每个LayerModel实现三个函数:eval, computePrevDelta, grad, 作为输出层的SoftmaxLayerModel有些特殊,额外具有LossFunction特性。
    可验证affine+activation LayerModel的计算组合跟BP1-4一致。

    AffineLayerModel (仿射变换层)

    • eval
      ( ext{output} = W cdot ext{input} + b)

    • computePrevDelta
      (prevdelta = W * delta)

    • grad
      $dot{W} = input cdot delta^l / ext{data size} $
      input is (a^{l-1}),前一层的激发函数输出
      (dot{b} = delta^l / ext{data size})

    FunctionalLayerModel(activate function (sigma))

    作为affineModel的activation model,只影响prev(delta) 的计算,grad不计算

    • eval
      ( ext{output} = sigma ( ext{input}))

    • computePrevDelta
      (delta :=delta * sigma'( ext{input}))

    • grad
      pass

    SoftmaxLayerModelWithCrossEntropyLoss

    作为最后一层激发函数,这一层很特殊。

    • eval
      计算参见手写公式。

    • computePrevDelta
      不计算

    • grad
      不计算

    • loss
      计算(delta^L),公式推导参见手写公式,代码如下:

        ApplyInPlace(output, target, delta, (o: Double, t: Double) => o - t)
    

    返回loss

    Softmax输出层的激发函数:
    (a^L_j = frac{e^{z^L_j}}{sum_k e^{z^L_k}})
    计算BP1:(delta^L_j = a^L_j -y_j)

    训练mnist手写数字识别

    import org.apache.spark.ml.classification.MultilayerPerceptronClassifier
    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
    import org.apache.spark.ml.linalg.Vectors
    import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
    
    object ann extends App {
      val spark = SparkSession
        .builder
        .appName("ANN for MNIST")
        .master("local[3]")
        .getOrCreate()
      spark.sparkContext.setLogLevel("ERROR")
    
      import spark.implicits._
    
      // Load the data stored in text as a DataFrame.
      val dataRdd: DataFrame= spark.sparkContext.textFile("handson-ml/data/train.csv")
        .map {
          line =>
            val linesp = line.split(",")
            val linespDouble = linesp.map(f => f.toDouble)
            (linespDouble.head, Vectors.dense(linespDouble.takeRight(linespDouble.length - 1)))
        }.toDF("label","features")
    
    
      val data = dataRdd
      // Split the data into train and test
      val splits: Array[DataFrame] = data.randomSplit(Array(0.6, 0.4), seed = 1234L)
      val train: Dataset[Row] = splits(0)
      val test: Dataset[Row] = splits(1)
    
    
      val layers = Array[Int](28*28, 300, 100, 10)
    
      // create the trainer and set its parameters
      val trainer = new MultilayerPerceptronClassifier()
        .setLayers(layers)
        .setBlockSize(128)
        .setSeed(1234L)
        .setMaxIter(100)
        .setLabelCol("label")
        .setFeaturesCol("features")
    
    
      // train the model
      val model = trainer.fit(train)
    
      // compute accuracy on the test set
      val result = model.transform(test)
      val predictionAndLabels = result.select("prediction", "label")
      val evaluator = new MulticlassClassificationEvaluator()
        .setMetricName("accuracy")
    
      println("Test set accuracy = " + evaluator.evaluate(predictionAndLabels))
    }
    

    后记

    测试集结果精度为96.68%。实际上并不高,同样的数据集使用TensorFlow训练,activation function选择ReLU,同样使用Softmax作为输出,结果可以达到98%以上。Sigmoid函数容易带来vanishing gradients问题,导致学习曲线变平。

  • 相关阅读:
    oracle 对应的JDBC驱动 版本
    Java Web中如何访问WEB-INF下的XML文件
    网站制作越简单越好(一):css样式命名规范
    HTTPClient以WebAPI方式发送formData数据上传文件
    NetCore(依赖注入)
    JS a标签 onClick问题
    NetCore的配置管理(1)
    Centos 系统安装NetCore SDK命令以及一系列操作(3)
    Centos 系统安装NetCore SDK命令以及一系列操作(2)
    Centos 系统安装NetCore SDK命令以及一系列操作(1)
  • 原文地址:https://www.cnblogs.com/luweiseu/p/7843761.html
Copyright © 2011-2022 走看看