zoukankan      html  css  js  c++  java
  • Spark线性回归实现优化

     1 import org.apache.log4j.{Level, Logger}
     2 import org.apache.spark.ml.feature.VectorAssembler
     3 import org.apache.spark.ml.regression.LinearRegression
     4 import org.apache.spark.sql.SparkSession
     5 
     /**
       * Linear regression example built on Spark ML.
       *
       * Reads a comma-separated text file whose lines hold eight numeric values in the order
       * Population, Income, Illiteracy, LifeExp, Murder, HSGrad, Frost, Area. `Murder` is used
       * as the label and the other seven columns are assembled into the feature vector. The
       * data is split 80/20; an elastic-net regularised model is fitted on the larger part and
       * its predictions for the smaller part are printed.
       *
       * Usage: `LinearRegression [trainFilePath]` - when no path is given, the original
       * hard-coded location is used.
       *
       * Note: this object shares its simple name with the imported Spark estimator class.
       * `new LinearRegression()` below still resolves to the Spark class because an `object`
       * only introduces a term name, not a type name.
       *
       * Created by zhen on 2018/11/12.
       */
     object LinearRegression {
       Logger.getLogger("org").setLevel(Level.WARN) // set the log level for the "org" logger

       /** Input file used when no path is passed on the command line. */
       private val DefaultTrainPath = "E:/BDS/newsparkml/src/train.txt"

       /** Name of the column the model is trained to predict. */
       private val LabelCol = "Murder"

       /** Columns assembled into the feature vector, in file order (label excluded). */
       private val FeatureCols =
         Array("Population", "Income", "Illiteracy", "LifeExp", "HSGrad", "Frost", "Area")

       /** Column order as it appears in the input file: the label sits after the 4th feature. */
       private val FileCols = (FeatureCols.take(4) :+ LabelCol) ++ FeatureCols.drop(4)

       /** Fixed seed so the train/test split (and therefore the printed metrics) is reproducible. */
       private val SplitSeed = 42L

       /**
         * Program entry point.
         *
         * @param args optional single argument: path of the training file; falls back to
         *             [[DefaultTrainPath]] when absent
         */
       def main(args: Array[String]): Unit = {
         val trainPath = args.headOption.getOrElse(DefaultTrainPath)
         val spark = SparkSession
           .builder()
           .appName("LinearRegression")
           .master("local[2]")
           .getOrCreate()
         // Always release the session, including when loading or training fails.
         try run(spark, trainPath)
         finally spark.stop()
       }

       /**
         * Loads the data, fits the model and prints coefficients, training metrics and the
         * predictions for the held-out split.
         *
         * @param spark     active session used for loading and training
         * @param trainPath path of the comma-separated input file
         */
       private def run(spark: SparkSession, trainPath: String): Unit = {
         // Load the data and parse each line into eight doubles.
         val records = spark.sparkContext
           .textFile(trainPath)
           .filter(_.trim.nonEmpty) // a blank (e.g. trailing) line would otherwise fail in toDouble
           .map { line =>
             line.split(",") match {
               // `_*` keeps accepting rows with extra trailing fields, which are ignored.
               case Array(population, income, illiteracy, lifeExp, murder, hsGrad, frost, area, _*) =>
                 (population.toDouble, income.toDouble, illiteracy.toDouble, lifeExp.toDouble,
                   murder.toDouble, hsGrad.toDouble, frost.toDouble, area.toDouble)
               case _ =>
                 throw new IllegalArgumentException(
                   s"Expected at least 8 comma-separated values but got: $line")
             }
           }
         val trainDF = spark.createDataFrame(records).toDF(FileCols: _*)

         val assembler = new VectorAssembler()
           .setInputCols(FeatureCols)
           .setOutputCol("features")
         val vectorDF = assembler.transform(trainDF)

         // Split into training (80%) and test (20%) sets.
         val Array(trainingSet, testSet) = vectorDF.randomSplit(Array(0.8, 0.2), SplitSeed)

         // Create the estimator.
         val linearRegression = new LinearRegression()
           .setFeaturesCol("features")
           .setLabelCol(LabelCol)
           .setFitIntercept(true)
           .setMaxIter(10)
           .setRegParam(0.3) // regularisation strength
           .setElasticNetParam(0.8)

         // Train the model.
         val lrModel = linearRegression.fit(trainingSet)

         // Inspect the fitted parameters.
         println(s"Cofficients:${lrModel.coefficients} Intercept:${lrModel.intercept}")

         // Evaluate on the training data.
         val trainingSummary = lrModel.summary
         println(s"objectiveHistoryList:${trainingSummary.objectiveHistory.toList}")
         println(s"r2:${trainingSummary.r2}")

         // Predict on the held-out split; the prediction is rounded to one decimal place.
         val predictions = lrModel.transform(testSet)
         val predictResult =
           predictions.selectExpr("features", LabelCol, "round(prediction,1) as prediction")
         println("训练集数据------------------------------真实值--预测值")
         // Collect to the driver before printing: Dataset.foreach runs on the executors (its
         // output is not guaranteed to reach the driver console) and `foreach(println(_))` is
         // an ambiguous overload under Scala 2.12. The result set here is small.
         predictResult.collect().foreach(println)
       }
     }

    结果:

  • 相关阅读:
    AngularJS中的Provider们:Service和Factory等的区别
    解决Eclipse建立Maven项目后无法建立src/main/java资源文件夹的办法
    关于EL表达式不起作用的问题
    Tomcat+Nginx 负载均衡配置,Tomcat+Nginx集群,Tomcat集群
    Java WebService 简单实例
    火狐浏览器中表单内容在表单刷新时候不重置表单信息
    ie文本框内容不居中问题
    javascript call和apply方法
    javascript的词法作用域
    C++提高编程 deque容器
  • 原文地址:https://www.cnblogs.com/yszd/p/9952268.html
Copyright © 2011-2022 走看看