zoukankan      html  css  js  c++  java
  • Spark机器学习(1):线性回归算法

    线性回归算法,是利用数理统计中回归分析,来确定两种或两种以上变量间相互依赖的定量关系的一种统计分析方法。

    1. 梯度下降法

    线性回归可以使用最小二乘法,但是速度比较慢,因此一般使用梯度下降法(Gradient Descent),梯度下降法又分为批量梯度下降法(Batch Gradient Descent)和随机梯度下降法(Stochastic Gradient Descent)。批量梯度下降法每次迭代需要使用训练集里面的所有数据,当训练集数据量较大时,速度就很慢;随机梯度下降法每次迭代只需要一个样本的数据,速度较快,对于大数据集,可能只需要使用少部分数据就达到收敛值,虽然有可能在最小值周围震荡,但是大多数情况下效果不错,所以,一般使用随机梯度下降法。

    2. Mllib的线性回归

    Mllib的线性回归采用的是随机梯度下降法。直接上代码:

    import org.apache.log4j.{ Level, Logger }
    import org.apache.spark.{ SparkConf, SparkContext }
    import org.apache.spark.mllib.regression.LinearRegressionWithSGD
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.linalg.Vectors
    
    object LinearRegression {
    
      def main(args: Array[String]) {
        // 设置运行环境
        val conf = new SparkConf().setAppName("Linear Regression Test").setMaster("spark://master:7077").setJars(Seq("E:\Intellij\Projects\MachineLearning\MachineLearning.jar"))
        val sc = new SparkContext(conf)
        Logger.getRootLogger.setLevel(Level.WARN)
    
        //读取样本数据,生成RDD
        val data_path = "hdfs://master:9000/ml/data/lpsa.data"
        val dataRDD = sc.textFile(data_path)
        val examples = dataRDD.map { line =>
          val parts = line.split(',')
          LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
        }.cache()// 迭代次数
        val numIterations = 100
        // 步长
        val stepSize = 0.5
        // 选取样本的比例
        val miniBatchFraction = 1.0
        // 用随机梯度下降模型训练
        val sgdModel = LinearRegressionWithSGD.train(examples, numIterations, stepSize, miniBatchFraction)
    
        // 对样本进行测试
        val prediction = sgdModel.predict(examples.map(_.features))
        val predictionAndLabel = prediction.zip(examples.map(_.label))
        // 选取前100个样本
        val show_predict = predictionAndLabel.take(100)
        println("Prediction" + "	" + "Label" + "	" + "Diff")
        for (i <- 0 to show_predict.length - 1) {
          val diff = show_predict(i)._1-show_predict(i)._2
          println(show_predict(i)._1 + "	" + show_predict(i)._2 + "	" + diff)
        }
    
      }
    
    }

    部分运行结果:

  • 相关阅读:
    Linux 配置 SSL 证书
    freemarker 异常处理
    StringTemplateLoader的用法
    序列的重点知识小结
    Linux下安装lrzsz上传下载工具
    ajax技术
    Response对象介绍(服务器到客户端)
    Request对象介绍(客户端到服务器)
    JSP--内置对象&动作标签介绍
    JSP--常用指令
  • 原文地址:https://www.cnblogs.com/mstk/p/7002775.html
Copyright © 2011-2022 走看看