zoukankan      html  css  js  c++  java
  • Spark线性回归实现优化

     1 import org.apache.log4j.{Level, Logger}
     2 import org.apache.spark.ml.feature.VectorAssembler
     3 import org.apache.spark.ml.regression.LinearRegression
     4 import org.apache.spark.sql.SparkSession
     5 
     /**
       * Linear regression example built on Spark ML.
       *
       * Reads a comma-separated text file with eight numeric fields per line,
       * assembles seven of them into a feature vector, fits an elastic-net
       * regularised linear regression that predicts the "Murder" column, and prints
       * the fitted parameters, the training summary and the predictions for the
       * held-out split.
       *
       * Created by zhen on 2018/11/12.
       */
     object LinearRegression {
       Logger.getLogger("org").setLevel(Level.WARN) // reduce log noise from "org.*" loggers

       /** Input file used when no path is passed on the command line. */
       private val DefaultTrainPath = "E:/BDS/newsparkml/src/train.txt"

       /** Feature columns, in the order they are assembled into the feature vector. */
       private val FeatureCols =
         Array("Population", "Income", "Illiteracy", "LifeExp", "HSGrad", "Frost", "Area")

       /** Name of the label (target) column. */
       private val LabelCol = "Murder"

       /** Number of comma-separated fields expected on every input line. */
       private val FieldCount = 8

       /**
         * Entry point.
         *
         * @param args optional; `args(0)` overrides the default training-file path
         */
       def main(args: Array[String]): Unit = {
         val spark = SparkSession
           .builder()
           .appName("LinearRegression")
           .master("local[2]")
           .getOrCreate()

         try {
           val trainPath = args.headOption.getOrElse(DefaultTrainPath)
           // Copied to a local so the RDD closure below does not capture this object.
           val expectedFields = FieldCount

           // Load the raw lines and parse each one into a tuple of eight doubles.
           val parsedRows = spark.sparkContext
             .textFile(trainPath)
             .filter(_.trim.nonEmpty) // tolerate blank lines instead of failing on them
             .map { line =>
               val f = line.split(",")
               require(
                 f.length >= expectedFields,
                 s"Expected $expectedFields comma-separated fields but found ${f.length}: $line"
               )
               (f(0).toDouble, f(1).toDouble, f(2).toDouble, f(3).toDouble,
                 f(4).toDouble, f(5).toDouble, f(6).toDouble, f(7).toDouble)
             }

           // The label is the fifth field of each line, between the fourth and fifth feature.
           val columnNames = (FeatureCols.take(4) :+ LabelCol) ++ FeatureCols.drop(4)
           val namedDF = spark.createDataFrame(parsedRows).toDF(columnNames: _*)

           val assembler = new VectorAssembler()
             .setInputCols(FeatureCols)
             .setOutputCol("features")
           val vectDF = assembler.transform(namedDF)

           // 80% of the rows for training, 20% held out for testing.
           val Array(trainingDF, testDF) = vectDF.randomSplit(Array(0.8, 0.2))

           // Fully qualified because this enclosing object shares the estimator's simple name.
           val estimator = new org.apache.spark.ml.regression.LinearRegression()
             .setFeaturesCol("features")
             .setLabelCol(LabelCol)
             .setFitIntercept(true)
             .setMaxIter(10)
             .setRegParam(0.3) // regularisation strength
             .setElasticNetParam(0.8) // L1/L2 mixing: 0 = pure L2, 1 = pure L1

           // Train the model.
           val lrModel = estimator.fit(trainingDF)

           // Inspect the fitted parameters.
           println(s"Cofficients:${lrModel.coefficients} Intercept:${lrModel.intercept}")

           // Evaluate the fit on the training data.
           val trainingSummary = lrModel.summary
           println(s"objectiveHistoryList:${trainingSummary.objectiveHistory.toList}")
           println(s"r2:${trainingSummary.r2}")

           // Predict on the held-out split, rounding predictions to one decimal place.
           val predictions = lrModel.transform(testDF)
           val predictResult =
             predictions.selectExpr("features", "Murder", "round(prediction,1) as prediction")

           // Header now names the test split, which is what is actually printed.
           println("测试集数据------------------------------真实值--预测值")
           // Bring the rows to the driver before printing: Dataset.foreach runs on executor
           // tasks, so its output order is nondeterministic and would be lost on a cluster.
           predictResult.collect().foreach(println)
         } finally {
           spark.stop() // always release the session, even when the job fails
         }
       }
     }

    结果:

  • 相关阅读:
    关闭编辑easyui datagrid table
    sql 保留两位小数+四舍五入
    easyui DataGrid 工具类之 util js
    easyui DataGrid 工具类之 后台生成列
    easyui DataGrid 工具类之 WorkbookUtil class
    easyui DataGrid 工具类之 TableUtil class
    easyui DataGrid 工具类之 Utils class
    easyui DataGrid 工具类之 列属性class
    oracle 卸载
    “云时代架构”经典文章阅读感想七
  • 原文地址:https://www.cnblogs.com/yszd/p/9952268.html
Copyright © 2011-2022 走看看