zoukankan      html  css  js  c++  java
  • Spark机器学习(1):线性回归算法

    线性回归算法,是利用数理统计中回归分析,来确定两种或两种以上变量间相互依赖的定量关系的一种统计分析方法。

    1. 梯度下降法

    线性回归可以使用最小二乘法,但是速度比较慢,因此一般使用梯度下降法(Gradient Descent),梯度下降法又分为批量梯度下降法(Batch Gradient Descent)和随机梯度下降法(Stochastic Gradient Descent)。批量梯度下降法每次迭代需要使用训练集里面的所有数据,当训练集数据量较大时,速度就很慢;随机梯度下降法每次迭代只需要一个样本的数据,速度较快,对于大数据集,可能只需要使用少部分数据就达到收敛值,虽然有可能在最小值周围震荡,但是大多数情况下效果不错,所以,一般使用随机梯度下降法。

    2. Mllib的线性回归

    Mllib的线性回归采用的是随机梯度下降法。直接上代码:

    import org.apache.log4j.{ Level, Logger }
    import org.apache.spark.{ SparkConf, SparkContext }
    import org.apache.spark.mllib.regression.LinearRegressionWithSGD
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.linalg.Vectors
    
    object LinearRegression {
    
      def main(args: Array[String]) {
        // 设置运行环境
        val conf = new SparkConf().setAppName("Linear Regression Test").setMaster("spark://master:7077").setJars(Seq("E:\Intellij\Projects\MachineLearning\MachineLearning.jar"))
        val sc = new SparkContext(conf)
        Logger.getRootLogger.setLevel(Level.WARN)
    
        //读取样本数据,生成RDD
        val data_path = "hdfs://master:9000/ml/data/lpsa.data"
        val dataRDD = sc.textFile(data_path)
        val examples = dataRDD.map { line =>
          val parts = line.split(',')
          LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
        }.cache()// 迭代次数
        val numIterations = 100
        // 步长
        val stepSize = 0.5
        // 选取样本的比例
        val miniBatchFraction = 1.0
        // 用随机梯度下降模型训练
        val sgdModel = LinearRegressionWithSGD.train(examples, numIterations, stepSize, miniBatchFraction)
    
        // 对样本进行测试
        val prediction = sgdModel.predict(examples.map(_.features))
        val predictionAndLabel = prediction.zip(examples.map(_.label))
        // 选取前100个样本
        val show_predict = predictionAndLabel.take(100)
        println("Prediction" + "	" + "Label" + "	" + "Diff")
        for (i <- 0 to show_predict.length - 1) {
          val diff = show_predict(i)._1-show_predict(i)._2
          println(show_predict(i)._1 + "	" + show_predict(i)._2 + "	" + diff)
        }
    
      }
    
    }

    部分运行结果:

  • 相关阅读:
    微访谈之1:解答各位朋友关心的问题
    深入浅出SQL Server中的死锁(实战篇)
    怎样玩转千万级别的数据
    Another MySQL daemon already running with the same unix socket
    c++ undefined reference to mysqlinit
    Another MySQL daemon already running with the same unix socket
    linxu 挂载分区
    C# RSA
    谷歌地图实现车辆轨迹移动播放(google map api)
    百度地图实现车辆轨迹移动播放(baidu map api)
  • 原文地址:https://www.cnblogs.com/mstk/p/7002775.html
Copyright © 2011-2022 走看看