zoukankan      html  css  js  c++  java
  • Spark 决策树--回归模型

    package Spark_MLlib
    
    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.evaluation.RegressionEvaluator
    import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.ml.linalg.{Vector, Vectors}
    import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
    
    /**
     * One labelled record of the input data set.
     *
     * The name is spelled `data_schema` (lower-case, with underscore) because that is
     * the constructor the driver's `main` method invokes when it parses each CSV line.
     *
     * @param features numeric feature vector of the record
     * @param label    class name of the record, kept as the raw string read from the file
     */
    case class data_schema(features: Vector, label: String)
    /**
     * Demo driver: trains a Spark ML decision-tree regressor on a small CSV data set
     * and prints the predictions, the RMSE and the learned tree.
     *
     * Each input line is expected to hold four numeric fields followed by a string
     * label, separated by commas. The file location may be passed as the first
     * command-line argument; when it is omitted the original hard-coded path is used.
     */
    object 决策树__回归模型 {
      /** Input file used when no path is supplied on the command line. */
      private val DefaultInputPath = "file:///home/soyo/桌面/spark编程测试数据/soyo2.txt"

      val spark = SparkSession.builder().master("local").getOrCreate()
      import spark.implicits._

      /**
       * Runs the whole load / train / predict / evaluate workflow.
       *
       * @param args optional; `args(0)` overrides the default input file path
       */
      def main(args: Array[String]): Unit = {
        val inputPath = args.headOption.getOrElse(DefaultInputPath)
        try {
          // Parse "f0,f1,f2,f3,label" lines; blank lines are skipped so that a trailing
          // newline in the file does not fail the numeric conversion.
          val data = spark.sparkContext.textFile(inputPath)
            .filter(_.trim.nonEmpty)
            .map(_.split(","))
            .map(x => data_schema(
              Vectors.dense(x(0).toDouble, x(1).toDouble, x(2).toDouble, x(3).toDouble),
              x(4)))
            .toDF()

          // Index the string labels and the feature vector (fitted on the full data set).
          val labelIndexer = new StringIndexer()
            .setInputCol("label")
            .setOutputCol("indexedLabel")
            .fit(data)
          val featuresIndexer = new VectorIndexer()
            .setInputCol("features")
            .setOutputCol("indexedFeatures")
            .setMaxCategories(4)
            .fit(data)
          // Maps a predicted index back to the original label string.
          val labelConverter = new IndexToString()
            .setInputCol("prediction")
            .setOutputCol("predictedLabel")
            .setLabels(labelIndexer.labels)

          val Array(trainData, testData) = data.randomSplit(Array(0.7, 0.3))

          // Configure the decision-tree regression model.
          // NOTE(review): the target here is an indexed class label, so the regressor's
          // continuous prediction is fed to IndexToString as if it were a label index;
          // confirm this is intended (a DecisionTreeClassifier is the usual choice).
          val dtRegressor = new DecisionTreeRegressor()
            .setLabelCol("indexedLabel")
            .setFeaturesCol("indexedFeatures")

          // Assemble the machine-learning pipeline.
          val pipelineRegressor = new Pipeline()
            .setStages(Array(labelIndexer, featuresIndexer, dtRegressor, labelConverter))

          // Train the decision-tree regression model.
          val modelRegressor = pipelineRegressor.fit(trainData)

          // Run the predictions on the held-out split.
          val prediction = modelRegressor.transform(testData)
          prediction.show(150)

          // Evaluate the model. setMetricName selects the metric; accepted values are
          // "rmse", "mse", "r2" and "mae".
          val evaluatorRegressor = new RegressionEvaluator()
            .setLabelCol("indexedLabel")
            .setPredictionCol("prediction")
            .setMetricName("rmse")
          val rootMeanSquaredError = evaluatorRegressor.evaluate(prediction)
          println(s"均方根误差为: $rootMeanSquaredError")

          // Locate the fitted tree by type instead of by position plus an unchecked cast.
          val treeModelRegressor = modelRegressor.stages
            .collectFirst { case tree: DecisionTreeRegressionModel => tree }
            .getOrElse(throw new IllegalStateException(
              "fitted pipeline contains no DecisionTreeRegressionModel stage"))
          val treeDescription = treeModelRegressor.toDebugString
          // The message text is kept exactly as originally printed.
          println(s"决策树分类模型的结构为: $treeDescription")
        } finally {
          // Release the local Spark resources even when the job fails.
          spark.stop()
        }
      }
    }
    Spark 源码:RegressionEvaluator.evaluate 方法(setMetricName 设置的度量名称在此处被匹配并计算)
    /**
     * Computes one regression metric over `dataset` and returns it as a Double.
     *
     * The prediction and label columns are taken from the `predictionCol` and
     * `labelCol` params, and the statistic returned is chosen by the `metricName`
     * param ("rmse", "mse", "r2" or "mae").
     *
     * @param dataset data set containing the prediction and label columns
     * @return the value of the selected metric
     */
    @Since("2.0.0")
      override def evaluate(dataset: Dataset[_]): Double = {
        val schema = dataset.schema
        // Validate the inputs first: the prediction column must be Double or Float,
        // and the label column must be numeric.
        SchemaUtils.checkColumnTypes(schema, $(predictionCol), Seq(DoubleType, FloatType))
        SchemaUtils.checkNumericType(schema, $(labelCol))
    
        // Cast both columns to Double and pair them up as (prediction, label).
        val predictionAndLabels = dataset
          .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType))
          .rdd
          .map { case Row(prediction: Double, label: Double) => (prediction, label) }
        val metrics = new RegressionMetrics(predictionAndLabels)
        // Select the statistic by name. There is no default case, so a name outside
        // these four would fail the match (presumably the param is validated
        // elsewhere; that is not visible in this excerpt).
        val metric = $(metricName) match {
          case "rmse" => metrics.rootMeanSquaredError
          case "mse" => metrics.meanSquaredError
          case "r2" => metrics.r2
          case "mae" => metrics.meanAbsoluteError
        }
        metric
    }
    

    结果:

    +-----------------+------+------------+-----------------+----------+--------------+
    |         features| label|indexedLabel|  indexedFeatures|prediction|predictedLabel|
    +-----------------+------+------------+-----------------+----------+--------------+
    |[4.6,3.1,1.5,0.2]|hadoop|         1.0|[4.6,3.1,1.5,0.2]|       1.0|        hadoop|
    |[4.6,3.4,1.4,0.3]|hadoop|         1.0|[4.6,3.4,1.4,0.3]|       1.0|        hadoop|
    |[4.7,3.2,1.3,0.2]|hadoop|         1.0|[4.7,3.2,1.3,0.2]|       1.0|        hadoop|
    |[4.8,3.0,1.4,0.1]|hadoop|         1.0|[4.8,3.0,1.4,0.1]|       1.0|        hadoop|
    |[5.1,3.3,1.7,0.5]|hadoop|         1.0|[5.1,3.3,1.7,0.5]|       1.0|        hadoop|
    |[5.1,3.7,1.5,0.4]|hadoop|         1.0|[5.1,3.7,1.5,0.4]|       1.0|        hadoop|
    |[5.4,3.9,1.3,0.4]|hadoop|         1.0|[5.4,3.9,1.3,0.4]|       1.0|        hadoop|
    |[5.5,2.3,4.0,1.3]| spark|         0.0|[5.5,2.3,4.0,1.3]|       0.0|         spark|
    |[5.5,3.5,1.3,0.2]|hadoop|         1.0|[5.5,3.5,1.3,0.2]|       1.0|        hadoop|
    |[5.6,2.7,4.2,1.3]| spark|         0.0|[5.6,2.7,4.2,1.3]|       0.0|         spark|
    |[5.6,3.0,4.1,1.3]| spark|         0.0|[5.6,3.0,4.1,1.3]|       0.0|         spark|
    |[5.6,3.0,4.5,1.5]| spark|         0.0|[5.6,3.0,4.5,1.5]|       0.0|         spark|
    |[5.7,2.6,3.5,1.0]| spark|         0.0|[5.7,2.6,3.5,1.0]|       0.0|         spark|
    |[5.7,4.4,1.5,0.4]|hadoop|         1.0|[5.7,4.4,1.5,0.4]|       1.0|        hadoop|
    |[5.8,2.7,3.9,1.2]| spark|         0.0|[5.8,2.7,3.9,1.2]|       0.0|         spark|
    |[5.8,2.7,4.1,1.0]| spark|         0.0|[5.8,2.7,4.1,1.0]|       0.0|         spark|
    |[5.8,2.8,5.1,2.4]| Scala|         2.0|[5.8,2.8,5.1,2.4]|       2.0|         Scala|
    |[5.8,4.0,1.2,0.2]|hadoop|         1.0|[5.8,4.0,1.2,0.2]|       1.0|        hadoop|
    |[5.9,3.0,4.2,1.5]| spark|         0.0|[5.9,3.0,4.2,1.5]|       0.0|         spark|
    |[5.9,3.0,5.1,1.8]| Scala|         2.0|[5.9,3.0,5.1,1.8]|       2.0|         Scala|
    |[5.9,3.2,4.8,1.8]| spark|         0.0|[5.9,3.2,4.8,1.8]|       2.0|         Scala|
    |[6.1,2.6,5.6,1.4]| Scala|         2.0|[6.1,2.6,5.6,1.4]|       2.0|         Scala|
    |[6.1,2.8,4.0,1.3]| spark|         0.0|[6.1,2.8,4.0,1.3]|       0.0|         spark|
    |[6.3,2.9,5.6,1.8]| Scala|         2.0|[6.3,2.9,5.6,1.8]|       2.0|         Scala|
    |[6.3,3.4,5.6,2.4]| Scala|         2.0|[6.3,3.4,5.6,2.4]|       2.0|         Scala|
    |[6.4,2.7,5.3,1.9]| Scala|         2.0|[6.4,2.7,5.3,1.9]|       2.0|         Scala|
    |[6.4,3.1,5.5,1.8]| Scala|         2.0|[6.4,3.1,5.5,1.8]|       2.0|         Scala|
    |[6.4,3.2,4.5,1.5]| spark|         0.0|[6.4,3.2,4.5,1.5]|       0.0|         spark|
    |[6.5,2.8,4.6,1.5]| spark|         0.0|[6.5,2.8,4.6,1.5]|       0.0|         spark|
    |[6.5,3.0,5.5,1.8]| Scala|         2.0|[6.5,3.0,5.5,1.8]|       2.0|         Scala|
    |[6.7,3.0,5.2,2.3]| Scala|         2.0|[6.7,3.0,5.2,2.3]|       2.0|         Scala|
    |[6.7,3.1,4.7,1.5]| spark|         0.0|[6.7,3.1,4.7,1.5]|       0.0|         spark|
    |[6.8,3.0,5.5,2.1]| Scala|         2.0|[6.8,3.0,5.5,2.1]|       2.0|         Scala|
    |[6.9,3.1,5.4,2.1]| Scala|         2.0|[6.9,3.1,5.4,2.1]|       2.0|         Scala|
    |[7.0,3.2,4.7,1.4]| spark|         0.0|[7.0,3.2,4.7,1.4]|       0.0|         spark|
    |[7.1,3.0,5.9,2.1]| Scala|         2.0|[7.1,3.0,5.9,2.1]|       2.0|         Scala|
    |[7.2,3.0,5.8,1.6]| Scala|         2.0|[7.2,3.0,5.8,1.6]|       0.0|         spark|
    |[7.2,3.2,6.0,1.8]| Scala|         2.0|[7.2,3.2,6.0,1.8]|       2.0|         Scala|
    |[7.2,3.6,6.1,2.5]| Scala|         2.0|[7.2,3.6,6.1,2.5]|       2.0|         Scala|
    |[7.4,2.8,6.1,1.9]| Scala|         2.0|[7.4,2.8,6.1,1.9]|       2.0|         Scala|
    |[7.7,2.6,6.9,2.3]| Scala|         2.0|[7.7,2.6,6.9,2.3]|       2.0|         Scala|
    |[7.7,2.8,6.7,2.0]| Scala|         2.0|[7.7,2.8,6.7,2.0]|       2.0|         Scala|
    +-----------------+------+------------+-----------------+----------+--------------+

    均方根误差为: 0.43643578047198484
    决策树分类模型的结构为: DecisionTreeRegressionModel (uid=dtr_6015411b1a3d) of depth 4 with 11 nodes
      If (feature 3 <= 1.7)
       If (feature 2 <= 1.9)
        Predict: 1.0
       Else (feature 2 > 1.9)
        If (feature 2 <= 4.9)
         If (feature 3 <= 1.6)
          Predict: 0.0
         Else (feature 3 > 1.6)
          Predict: 2.0
        Else (feature 2 > 4.9)
         If (feature 3 <= 1.5)
          Predict: 2.0
         Else (feature 3 > 1.5)
          Predict: 0.0
      Else (feature 3 > 1.7)
       Predict: 2.0

  • 相关阅读:
    find the safest road HDU
    分页存储过程
    .NET Core与.NET Framework、Mono之间的关系
    winForm开发
    面试题目总结
    sqlserver锁表、解锁、查看锁表
    架构漫谈(四):如何做好架构之架构切分
    多线程讲解
    递归菜单简单应用
    杂记
  • 原文地址:https://www.cnblogs.com/soyo/p/7793664.html
Copyright © 2011-2022 走看看