zoukankan      html  css  js  c++  java
  • Spark 决策树--分类模型

    package Spark_MLlib
    
    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
    import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
    import org.apache.spark.ml.linalg.{Vector, Vectors}
    import org.apache.spark.mllib.tree.DecisionTree
    import org.apache.spark.sql.SparkSession
    
    /**
      * One parsed input row: a dense feature vector together with its raw string label.
      *
      * The field names are what `toDF()` turns into the DataFrame column names
      * ("features", "label"), so they must match the columns the pipeline reads.
      *
      * Created by soyo on 17-11-5.
      */
    case class data_schemas(
      features: Vector,
      label: String
    )
    /**
      * Trains and evaluates a decision-tree classifier with a Spark ML Pipeline.
      *
      * Input is a text file of comma-separated rows: four numeric features followed by a
      * string label (e.g. `5.1,3.5,1.4,0.2,hadoop`). The label and feature columns are
      * indexed, a DecisionTreeClassifier is trained on a random 70% split, and the
      * predictions, accuracy, error rate and learned tree are printed for the other 30%.
      */
    object 决策树 {
      val spark = SparkSession.builder().master("local").appName("决策树").getOrCreate()
      import spark.implicits._

      /** Input file used when no path is given on the command line. */
      private val DefaultInputPath = "file:///home/soyo/桌面/spark编程测试数据/soyo2.txt"

      /**
        * Entry point.
        *
        * @param args optional; `args(0)`, if present, overrides the input path
        *             (defaults to `DefaultInputPath`)
        */
      def main(args: Array[String]): Unit = {
        val inputPath = args.headOption.getOrElse(DefaultInputPath)
        // Stop the session even if loading or training throws, so the local
        // SparkContext is not left running.
        try run(inputPath)
        finally spark.stop()
      }

      /** Loads `inputPath`, trains the pipeline, and prints predictions and metrics. */
      private def run(inputPath: String): Unit = {
        val source_DF = spark.sparkContext.textFile(inputPath)
          .map(_.trim)        // also strips a trailing '\r' from CRLF files
          .filter(_.nonEmpty) // a blank line (e.g. at EOF) would otherwise fail on x(0).toDouble
          .map(_.split(","))
          .map(x => data_schemas(
            Vectors.dense(x(0).toDouble, x(1).toDouble, x(2).toDouble, x(3).toDouble),
            x(4).trim))
          .toDF()
        source_DF.createOrReplaceTempView("decisonTree")
        val DF = spark.sql("select * from decisonTree")
        DF.show()

        // Index the label and feature columns under new names. Indexing turns the string
        // label into numeric class ids the learning algorithm can use. Both indexers are
        // fitted on the full DF (before the split), so values that occur only in the
        // test split are still known to them.
        val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(DF)
        // Features with at most 4 distinct values are treated as categorical.
        val featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(4).fit(DF)
        // Maps predicted class ids back to the original string labels.
        val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)

        // Training and test data (70% / 30%). No seed is set, so the split differs per run.
        val Array(trainData, testData) = DF.randomSplit(Array(0.7, 0.3))
        val decisionTreeClassifier = new DecisionTreeClassifier().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures")

        // Build the ML pipeline.
        val dt_pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, decisionTreeClassifier, labelConverter))
        val dt_model = dt_pipeline.fit(trainData)

        // Predict on the held-out data.
        val dtprediction = dt_model.transform(testData)
        dtprediction.show(150)

        // Evaluate the decision-tree model.
        val evaluatorClassifier = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("accuracy")
        val accuracy = evaluatorClassifier.evaluate(dtprediction)
        println("准确率为: " + accuracy)
        val error = 1 - accuracy
        println("错误率为: " + error)

        // Locate the fitted tree by type rather than by position, so reordering the
        // pipeline stages cannot turn into a ClassCastException.
        val treeModelClassifier = dt_model.stages
          .collectFirst { case m: DecisionTreeClassificationModel => m }
          .getOrElse(throw new IllegalStateException("pipeline has no DecisionTreeClassificationModel stage"))
        val schema_DecisionTree = treeModelClassifier.toDebugString
        println("决策树的模型结构为: " + schema_DecisionTree)
      }
    }

    结果为:

    +-----------------+------+
    |         features| label|
    +-----------------+------+
    |[5.1,3.5,1.4,0.2]|hadoop|
    |[4.9,3.0,1.4,0.2]|hadoop|
    |[4.7,3.2,1.3,0.2]|hadoop|
    |[4.6,3.1,1.5,0.2]|hadoop|
    |[5.0,3.6,1.4,0.2]|hadoop|
    |[5.4,3.9,1.7,0.4]|hadoop|
    |[4.6,3.4,1.4,0.3]|hadoop|
    |[5.0,3.4,1.5,0.2]|hadoop|
    |[4.4,2.9,1.4,0.2]|hadoop|
    |[4.9,3.1,1.5,0.1]|hadoop|
    |[5.4,3.7,1.5,0.2]|hadoop|
    |[4.8,3.4,1.6,0.2]|hadoop|
    |[4.8,3.0,1.4,0.1]|hadoop|
    |[4.3,3.0,1.1,0.1]|hadoop|
    |[5.8,4.0,1.2,0.2]|hadoop|
    |[5.7,4.4,1.5,0.4]|hadoop|
    |[5.4,3.9,1.3,0.4]|hadoop|
    |[5.1,3.5,1.4,0.3]|hadoop|
    |[5.7,3.8,1.7,0.3]|hadoop|
    |[5.1,3.8,1.5,0.3]|hadoop|
    +-----------------+------+
    only showing top 20 rows

    +-----------------+------+------------+-----------------+--------------+-------------+----------+--------------+
    |         features| label|indexedLabel|  indexedFeatures| rawPrediction|  probability|prediction|predictedLabel|
    +-----------------+------+------------+-----------------+--------------+-------------+----------+--------------+
    |[4.4,3.0,1.3,0.2]|hadoop|         1.0|[4.4,3.0,1.3,0.2]|[0.0,36.0,0.0]|[0.0,1.0,0.0]|       1.0|        hadoop|
    |[4.6,3.4,1.4,0.3]|hadoop|         1.0|[4.6,3.4,1.4,0.3]|[0.0,36.0,0.0]|[0.0,1.0,0.0]|       1.0|        hadoop|
    |[4.6,3.6,1.0,0.2]|hadoop|         1.0|[4.6,3.6,1.0,0.2]|[0.0,36.0,0.0]|[0.0,1.0,0.0]|       1.0|        hadoop|
    |[4.9,2.4,3.3,1.0]| spark|         0.0|[4.9,2.4,3.3,1.0]| [0.0,0.0,1.0]|[0.0,0.0,1.0]|       2.0|         Scala|
    |[5.0,2.0,3.5,1.0]| spark|         0.0|[5.0,2.0,3.5,1.0]| [1.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|         spark|
    |[5.0,2.3,3.3,1.0]| spark|         0.0|[5.0,2.3,3.3,1.0]|[29.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|         spark|
    |[5.0,3.2,1.2,0.2]|hadoop|         1.0|[5.0,3.2,1.2,0.2]|[0.0,36.0,0.0]|[0.0,1.0,0.0]|       1.0|        hadoop|
    |[5.0,3.3,1.4,0.2]|hadoop|         1.0|[5.0,3.3,1.4,0.2]|[0.0,36.0,0.0]|[0.0,1.0,0.0]|       1.0|        hadoop|
    |[5.0,3.4,1.6,0.4]|hadoop|         1.0|[5.0,3.4,1.6,0.4]|[0.0,36.0,0.0]|[0.0,1.0,0.0]|       1.0|        hadoop|
    |[5.0,3.6,1.4,0.2]|hadoop|         1.0|[5.0,3.6,1.4,0.2]|[0.0,36.0,0.0]|[0.0,1.0,0.0]|       1.0|        hadoop|
    |[5.1,3.5,1.4,0.2]|hadoop|         1.0|[5.1,3.5,1.4,0.2]|[0.0,36.0,0.0]|[0.0,1.0,0.0]|       1.0|        hadoop|
    |[5.1,3.7,1.5,0.4]|hadoop|         1.0|[5.1,3.7,1.5,0.4]|[0.0,36.0,0.0]|[0.0,1.0,0.0]|       1.0|        hadoop|
    |[5.2,3.4,1.4,0.2]|hadoop|         1.0|[5.2,3.4,1.4,0.2]|[0.0,36.0,0.0]|[0.0,1.0,0.0]|       1.0|        hadoop|
    |[5.2,4.1,1.5,0.1]|hadoop|         1.0|[5.2,4.1,1.5,0.1]|[0.0,36.0,0.0]|[0.0,1.0,0.0]|       1.0|        hadoop|
    |[5.4,3.0,4.5,1.5]| spark|         0.0|[5.4,3.0,4.5,1.5]|[29.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|         spark|
    |[5.4,3.9,1.7,0.4]|hadoop|         1.0|[5.4,3.9,1.7,0.4]|[0.0,36.0,0.0]|[0.0,1.0,0.0]|       1.0|        hadoop|
    |[5.5,2.4,3.7,1.0]| spark|         0.0|[5.5,2.4,3.7,1.0]|[29.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|         spark|
    |[5.5,2.4,3.8,1.1]| spark|         0.0|[5.5,2.4,3.8,1.1]|[29.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|         spark|
    |[5.5,2.5,4.0,1.3]| spark|         0.0|[5.5,2.5,4.0,1.3]|[29.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|         spark|
    |[5.5,2.6,4.4,1.2]| spark|         0.0|[5.5,2.6,4.4,1.2]|[29.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|         spark|
    |[5.5,4.2,1.4,0.2]|hadoop|         1.0|[5.5,4.2,1.4,0.2]|[0.0,36.0,0.0]|[0.0,1.0,0.0]|       1.0|        hadoop|
    |[5.6,2.5,3.9,1.1]| spark|         0.0|[5.6,2.5,3.9,1.1]|[29.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|         spark|
    |[5.6,2.7,4.2,1.3]| spark|         0.0|[5.6,2.7,4.2,1.3]|[29.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|         spark|
    |[5.6,3.0,4.1,1.3]| spark|         0.0|[5.6,3.0,4.1,1.3]|[29.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|         spark|
    |[5.7,2.6,3.5,1.0]| spark|         0.0|[5.7,2.6,3.5,1.0]|[29.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|         spark|
    |[5.8,2.6,4.0,1.2]| spark|         0.0|[5.8,2.6,4.0,1.2]|[29.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|         spark|
    |[5.8,4.0,1.2,0.2]|hadoop|         1.0|[5.8,4.0,1.2,0.2]|[0.0,36.0,0.0]|[0.0,1.0,0.0]|       1.0|        hadoop|
    |[6.1,2.6,5.6,1.4]| Scala|         2.0|[6.1,2.6,5.6,1.4]|[29.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|         spark|
    |[6.2,2.2,4.5,1.5]| spark|         0.0|[6.2,2.2,4.5,1.5]| [0.0,0.0,1.0]|[0.0,0.0,1.0]|       2.0|         Scala|
    |[6.2,3.4,5.4,2.3]| Scala|         2.0|[6.2,3.4,5.4,2.3]|[0.0,0.0,31.0]|[0.0,0.0,1.0]|       2.0|         Scala|
    |[6.3,2.5,5.0,1.9]| Scala|         2.0|[6.3,2.5,5.0,1.9]|[0.0,0.0,31.0]|[0.0,0.0,1.0]|       2.0|         Scala|
    |[6.3,2.8,5.1,1.5]| Scala|         2.0|[6.3,2.8,5.1,1.5]|[29.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|         spark|
    |[6.4,2.8,5.6,2.1]| Scala|         2.0|[6.4,2.8,5.6,2.1]|[0.0,0.0,31.0]|[0.0,0.0,1.0]|       2.0|         Scala|
    |[6.4,2.8,5.6,2.2]| Scala|         2.0|[6.4,2.8,5.6,2.2]|[0.0,0.0,31.0]|[0.0,0.0,1.0]|       2.0|         Scala|
    |[6.4,3.2,4.5,1.5]| spark|         0.0|[6.4,3.2,4.5,1.5]|[29.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|         spark|
    |[6.4,3.2,5.3,2.3]| Scala|         2.0|[6.4,3.2,5.3,2.3]|[0.0,0.0,31.0]|[0.0,0.0,1.0]|       2.0|         Scala|
    |[6.5,2.8,4.6,1.5]| spark|         0.0|[6.5,2.8,4.6,1.5]|[29.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|         spark|
    |[6.6,2.9,4.6,1.3]| spark|         0.0|[6.6,2.9,4.6,1.3]|[29.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|         spark|
    |[6.6,3.0,4.4,1.4]| spark|         0.0|[6.6,3.0,4.4,1.4]|[29.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|         spark|
    |[6.8,3.2,5.9,2.3]| Scala|         2.0|[6.8,3.2,5.9,2.3]|[0.0,0.0,31.0]|[0.0,0.0,1.0]|       2.0|         Scala|
    |[6.9,3.1,4.9,1.5]| spark|         0.0|[6.9,3.1,4.9,1.5]|[29.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|         spark|
    |[6.9,3.2,5.7,2.3]| Scala|         2.0|[6.9,3.2,5.7,2.3]|[0.0,0.0,31.0]|[0.0,0.0,1.0]|       2.0|         Scala|
    |[7.2,3.0,5.8,1.6]| Scala|         2.0|[7.2,3.0,5.8,1.6]|[29.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|         spark|
    |[7.2,3.2,6.0,1.8]| Scala|         2.0|[7.2,3.2,6.0,1.8]|[0.0,0.0,31.0]|[0.0,0.0,1.0]|       2.0|         Scala|
    |[7.6,3.0,6.6,2.1]| Scala|         2.0|[7.6,3.0,6.6,2.1]|[0.0,0.0,31.0]|[0.0,0.0,1.0]|       2.0|         Scala|
    |[7.7,3.0,6.1,2.3]| Scala|         2.0|[7.7,3.0,6.1,2.3]|[0.0,0.0,31.0]|[0.0,0.0,1.0]|       2.0|         Scala|
    |[7.7,3.8,6.7,2.2]| Scala|         2.0|[7.7,3.8,6.7,2.2]|[0.0,0.0,31.0]|[0.0,0.0,1.0]|       2.0|         Scala|
    |[7.9,3.8,6.4,2.0]| Scala|         2.0|[7.9,3.8,6.4,2.0]|[0.0,0.0,31.0]|[0.0,0.0,1.0]|       2.0|         Scala|
    +-----------------+------+------------+-----------------+--------------+-------------+----------+--------------+

    准确率为: 0.8958333333333334
    错误率为: 0.10416666666666663
    决策树的模型结构为: DecisionTreeClassificationModel (uid=dtc_218264842cd2) of depth 5 with 15 nodes
      If (feature 2 <= 1.9)
       Predict: 1.0
      Else (feature 2 > 1.9)
       If (feature 3 <= 1.7)
        If (feature 0 <= 4.9)
         Predict: 2.0
        Else (feature 0 > 4.9)
         If (feature 1 <= 2.2)
          If (feature 2 <= 4.0)
           Predict: 0.0
          Else (feature 2 > 4.0)
           Predict: 2.0
         Else (feature 1 > 2.2)
          Predict: 0.0
       Else (feature 3 > 1.7)
        If (feature 2 <= 4.8)
         If (feature 0 <= 5.9)
          Predict: 0.0
         Else (feature 0 > 5.9)
          Predict: 2.0
        Else (feature 2 > 4.8)
         Predict: 2.0


  • 相关阅读:
    程序打包
    MFC AfxMessageBox默认标题修改
    Json
    agsXMPP
    xmpp
    afxcomctl32.h与afxcomctl32.inl报错
    jQuery使用
    EChart使用
    C++ tinyXML使用
    electron之Windows下使用 html js css 开发桌面应用程序
  • 原文地址:https://www.cnblogs.com/soyo/p/7792977.html
Copyright © 2011-2022 走看看