zoukankan      html  css  js  c++  java
  • 掌握Spark机器学习库-08.7-决策树算法实现分类

    数据集

    iris.data

    数据集概览

    代码

    package org.apache.spark.examples.examplesforml
    
    import org.apache.spark.SparkConf
    import org.apache.spark.ml.classification.{DecisionTreeClassifier, NaiveBayes}
    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
    import org.apache.spark.ml.feature.VectorAssembler
    import org.apache.spark.sql.SparkSession
    
    import scala.util.Random
    
    /**
     * Trains a decision-tree classifier on the iris data set and prints its
     * accuracy on a held-out split.
     *
     * The input is read as a header-less CSV whose first four columns are parsed
     * as numeric features and whose fifth column is the class name
     * (`Iris-setosa`, `Iris-versicolor` or `Iris-virginica`).
     */
    object DeTree {

      /** Data file used when no path is supplied on the command line. */
      private val DefaultDataPath = "D:\\8-6决策树\\iris.data"

      /**
       * Program entry point.
       *
       * @param args optional; `args(0)` overrides the location of `iris.data`.
       *             When absent, [[DefaultDataPath]] is used.
       */
      def main(args: Array[String]): Unit = {
        val dataPath = args.headOption.getOrElse(DefaultDataPath)

        val conf = new SparkConf().setMaster("local").setAppName("iris")
        val spark = SparkSession.builder().config(conf).getOrCreate()

        // Always release the session, even when loading or training fails.
        try {
          spark.sparkContext.setLogLevel("WARN") // log level

          val file = spark.read.format("csv").load(dataPath)
          //file.show()

          import spark.implicits._
          val random = new Random()
          val data = file.map(row => {
            // Encode the class name (5th column) as a numeric label.
            val label = row.getString(4) match {
              case "Iris-setosa" => 0
              case "Iris-versicolor" => 1
              case "Iris-virginica" => 2
              // Fail with a readable message instead of an opaque MatchError
              // (this case also catches a null class column).
              case other =>
                throw new IllegalArgumentException(s"Unknown iris class label: $other")
            }

            (row.getString(0).toDouble,
              row.getString(1).toDouble,
              row.getString(2).toDouble,
              row.getString(3).toDouble,
              label,
              random.nextDouble())
            // Sorting by the random column shuffles the rows before splitting.
          }).toDF("_c0", "_c1", "_c2", "_c3", "label", "rand").sort("rand") //.where("label = 1 or label = 0")

          // Pack the four numeric columns into the single vector column Spark ML expects.
          val assembler = new VectorAssembler()
            .setInputCols(Array("_c0", "_c1", "_c2", "_c3"))
            .setOutputCol("features")

          val dataset = assembler.transform(data)
          // 80% of the rows for training, 20% for evaluation (unseeded, so runs differ).
          val Array(train, test) = dataset.randomSplit(Array(0.8, 0.2))

          val dt = new DecisionTreeClassifier().setFeaturesCol("features").setLabelCol("label")
          val model = dt.fit(train)
          val result = model.transform(test)
          result.show()

          val evaluator = new MulticlassClassificationEvaluator()
            .setLabelCol("label")
            .setPredictionCol("prediction")
            .setMetricName("accuracy")
          val accuracy = evaluator.evaluate(result)
          println(s"""accuracy is $accuracy""")
        } finally {
          spark.stop()
        }
      }
    }

    输出结果:

  • 相关阅读:
    程序清单 8-8 exec函数实例,a.out是程序8-9产生的可执行程序
    程序清单8-9 回送所有命令行参数和所有环境字符串
    程序清单8-3 8-4 演示不同的exit值
    C和指针 3.9作用域、存储类型示例
    程序4-6 utime函数实例
    程序4-5 打开一个文件,然后unlink
    C和指针笔记 3.8 static关键字
    C和指针笔记 3.7 存储类型
    C和指针笔记 3.6链接属性
    python爬虫<urlopen error [Errno 10061] >
  • 原文地址:https://www.cnblogs.com/moonlightml/p/9789707.html
Copyright © 2011-2022 走看看