zoukankan      html  css  js  c++  java
  • Spark MLlib 示例代码阅读

    阅读前提:需要有一定的机器学习基础。本文重点面向应用;至于机器学习相关的复杂理论和优化理论,建议多读论文,初学者推荐 Andrew Ng 的公开课。

    /*
    * Licensed to the Apache Software Foundation (ASF) under one or more
    * contributor license agreements. See the NOTICE file distributed with
    * this work for additional information regarding copyright ownership.
    * The ASF licenses this file to You under the Apache License, Version 2.0
    * (the "License"); you may not use this file except in compliance with
    * the License. You may obtain a copy of the License at
    *
    * http://www.apache.org/licenses/LICENSE-2.0
    *
    * Unless required by applicable law or agreed to in writing, software
    * distributed under the License is distributed on an "AS IS" BASIS,
    * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    * See the License for the specific language governing permissions and
    * limitations under the License.
    */

    package org.apache.spark.examples.mllib

    import org.apache.log4j.{Level, Logger}
    import scopt.OptionParser

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, SVMWithSGD}
    import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
    import org.apache.spark.mllib.util.MLUtils
    import org.apache.spark.mllib.optimization.{SquaredL2Updater, L1Updater}

    /**
    * An example app for binary classification. Run with
    * {{{
    * bin/run-example org.apache.spark.examples.mllib.BinaryClassification
    * }}}
    * A synthetic dataset is located at `data/mllib/sample_binary_classification_data.txt`.
    * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
    */
    object BinaryClassification {

      /** Binary classifiers this example can train. */
      object Algorithm extends Enumeration {
        type Algorithm = Value
        val SVM, LR = Value
      }

      /** Regularization flavours selectable from the command line. */
      object RegType extends Enumeration {
        type RegType = Value
        val L1, L2 = Value
      }

      import Algorithm._
      import RegType._

      /**
       * Command-line parameters of the example.
       *
       * @param input         path(s) to the labeled examples in LIBSVM format
       * @param numIterations number of optimizer iterations; can be passed in at submit time
       * @param stepSize      step size (only handed to the SVM optimizer below)
       * @param algorithm     classifier to train; logistic regression by default
       * @param regType       regularization type; L2 by default
       * @param regParam      regularization parameter used with the L1/L2 updater
       */
      case class Params(
          input: String = null,
          numIterations: Int = 100,
          stepSize: Double = 1.0,
          algorithm: Algorithm = LR,
          regType: RegType = L2,
          regParam: Double = 0.01) extends AbstractParams[Params]

      /**
       * Entry point: parses the command line into [[Params]] and runs the example.
       * Exits with status 1 when the arguments cannot be parsed.
       */
      def main(args: Array[String]): Unit = {
        val defaultParams = Params()

        val parser = new OptionParser[Params]("BinaryClassification") {
          head("BinaryClassification: an example app for binary classification.")
          opt[Int]("numIterations")
            .text("number of iterations")
            .action((x, c) => c.copy(numIterations = x))
          opt[Double]("stepSize")
            .text("initial step size (ignored by logistic regression), " +
              s"default: ${defaultParams.stepSize}")
            .action((x, c) => c.copy(stepSize = x))
          opt[String]("algorithm")
            .text(s"algorithm (${Algorithm.values.mkString(",")}), " +
              s"default: ${defaultParams.algorithm}")
            .action((x, c) => c.copy(algorithm = Algorithm.withName(x)))
          opt[String]("regType")
            .text(s"regularization type (${RegType.values.mkString(",")}), " +
              s"default: ${defaultParams.regType}")
            .action((x, c) => c.copy(regType = RegType.withName(x)))
          opt[Double]("regParam")
            .text(s"regularization parameter, default: ${defaultParams.regParam}")
            // Without this action the parsed value was dropped and the default always won.
            .action((x, c) => c.copy(regParam = x))
          arg[String]("<input>")
            .required()
            .text("input paths to labeled examples in LIBSVM format")
            .action((x, c) => c.copy(input = x))
          note(
            """
            |For example, the following command runs this app on a synthetic dataset:
            |
            | bin/spark-submit --class org.apache.spark.examples.mllib.BinaryClassification
            | examples/target/scala-*/spark-examples-*.jar
            | --algorithm LR --regType L2 --regParam 1.0
            | data/mllib/sample_binary_classification_data.txt
            """.stripMargin)
        }

        // `params` carries every setting needed by the logistic-regression / SVM run.
        parser.parse(args, defaultParams) match {
          case Some(params) => run(params)
          case None => sys.exit(1)
        }
      }

      /**
       * Trains the selected classifier on a random 80% split of the input and prints
       * the area under the PR and ROC curves measured on the remaining 20%.
       *
       * @param params parsed command-line parameters
       */
      def run(params: Params): Unit = {
        // Create the SparkConf and the SparkContext for this application.
        val conf = new SparkConf().setAppName(s"BinaryClassification with $params")
        val sc = new SparkContext(conf)

        // Stop the context even when loading, training or evaluation throws.
        try {
          Logger.getRootLogger.setLevel(Level.WARN)

          // `params.input` is the path of the sample file.
          val examples = MLUtils.loadLibSVMFile(sc, params.input).cache()

          // Random split: 80% of the examples for training, 20% for testing.
          val splits = examples.randomSplit(Array(0.8, 0.2))
          val training = splits(0).cache()
          val test = splits(1).cache()

          val numTraining = training.count()
          val numTest = test.count()
          println(s"Training: $numTraining, test: $numTest.")

          // Both splits are cached and counted above, so the source RDD can be released.
          examples.unpersist(blocking = false)

          // Pick the updater that implements the requested regularization.
          val updater = params.regType match {
            case L1 => new L1Updater()
            case L2 => new SquaredL2Updater()
          }

          // Train logistic regression or an SVM, depending on the requested algorithm.
          val model = params.algorithm match {
            case LR =>
              // NOTE(review): Spark's LBFGS documentation warns that L1Updater is not
              // supported by this optimizer - confirm before relying on --regType L1 with LR.
              val algorithm = new LogisticRegressionWithLBFGS()
              algorithm.optimizer
                .setNumIterations(params.numIterations)
                .setUpdater(updater)
                .setRegParam(params.regParam)
              // clearThreshold() is applied so predict() yields raw scores for the metrics.
              algorithm.run(training).clearThreshold()
            case SVM =>
              val algorithm = new SVMWithSGD()
              algorithm.optimizer
                .setNumIterations(params.numIterations)
                .setStepSize(params.stepSize)
                .setUpdater(updater)
                .setRegParam(params.regParam)
              algorithm.run(training).clearThreshold()
          }

          // Score the held-out split and pair each score with its true label.
          val prediction = model.predict(test.map(_.features))
          val predictionAndLabel = prediction.zip(test.map(_.label))

          val metrics = new BinaryClassificationMetrics(predictionAndLabel)

          println(s"Test areaUnderPR = ${metrics.areaUnderPR()}.")
          println(s"Test areaUnderROC = ${metrics.areaUnderROC()}.")
        } finally {
          sc.stop()
        }
      }
    }

  • 相关阅读:
    SetConsoleScreenBufferSize 函数--设置控制台屏幕缓冲区大小
    GetConsoleScreenBufferInfo 函数--获取控制台屏幕缓冲区信息
    CONSOLE_SCREEN_BUFFER_INFO 结构体
    GetStdHandle 函数--获取标准设备的句柄
    设计模式之代理模式(Proxy Pattern)_远程代理解析
    设计模式之状态模式(State Pattern)
    设计模式之组合模式(Composite Pattern)
    设计模式之迭代器模式(Iterator Pattern)
    设计模式之模版方法模式(Template Method Pattern)
    设计模式之外观模式(Facade Pattern)
  • 原文地址:https://www.cnblogs.com/inspursu/p/4277216.html
Copyright © 2011-2022 走看看