  • Spark DecisionTreeClassifier: Decision Tree Classification

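    1. Overview

    Decision trees and their ensembles are popular methods for the machine learning tasks of classification and regression. Decision trees are widely used because they are easy to interpret, handle categorical features, extend to the multiclass classification setting, do not require feature scaling, and can capture non-linearities and feature interactions. Tree ensemble algorithms such as random forests and boosting are among the top performers for classification and regression tasks.
    The spark.ml implementation supports decision trees for binary and multiclass classification and for regression, using both continuous and categorical features. The implementation partitions data by rows, which allows distributed training on millions or even billions of instances.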

    2. Inputs and outputs

    All output columns are optional; to exclude an output column, set its corresponding Param to an empty string.
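    As a minimal sketch of excluding output columns, the snippet below configures a classifier that emits only the prediction column. It assumes spark-mllib is on the classpath; the object name `ExcludeOutputColumns` and the helper `build` are hypothetical names introduced for illustration.

```scala
import org.apache.spark.ml.classification.DecisionTreeClassifier

// Hypothetical helper object, used only for this illustration.
object ExcludeOutputColumns {
  // Configure a classifier whose rawPrediction and probability columns are disabled.
  def build(): DecisionTreeClassifier =
    new DecisionTreeClassifier()
      .setLabelCol("label")
      .setFeaturesCol("features")
      .setRawPredictionCol("") // empty string: do not emit rawPrediction
      .setProbabilityCol("")   // empty string: do not emit probability

  def main(args: Array[String]): Unit = {
    val dt = build()
    println(s"rawPredictionCol = '${dt.getRawPredictionCol}'")
    println(s"probabilityCol   = '${dt.getProbabilityCol}'")
    println(s"predictionCol    = '${dt.getPredictionCol}'")
  }
}
```

    A model fitted from this estimator should then add only the prediction column when it transforms a DataFrame.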

    Input Columns

    Param name  | Type(s) | Default    | Description
    labelCol    | Double  | "label"    | Label to predict
    featuresCol | Vector  | "features" | Feature vector

    Output Columns

    Param name       | Type(s) | Default         | Description | Notes
    predictionCol    | Double  | "prediction"    | Predicted label |
    rawPredictionCol | Vector  | "rawPrediction" | Vector of length # classes, with the counts of training instance labels at the tree node which makes the prediction | Classification only
    probabilityCol   | Vector  | "probability"   | Vector of length # classes equal to rawPrediction normalized to a multinomial distribution | Classification only
    varianceCol      | Double  |                 | The biased sample variance of prediction | Regression only
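    The relationship between rawPrediction and probability can be sketched in plain Scala: the per-class counts at the predicting leaf are divided by their sum. This is a conceptual illustration, not Spark's internal code; the object name, the leaf counts, and the handling of an empty leaf are assumptions made for the example.

```scala
// Hypothetical helper object, used only for this illustration.
object RawPredictionToProbability {
  // Divide each per-class count by the total so the result sums to 1.
  def normalize(counts: Array[Double]): Array[Double] = {
    val total = counts.sum
    // Assumption for this sketch: an empty leaf yields all zeros.
    if (total == 0.0) counts.map(_ => 0.0) else counts.map(_ / total)
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical leaf that saw 0, 47 and 3 training instances of classes 0, 1 and 2.
    val rawPrediction = Array(0.0, 47.0, 3.0)
    val probability = normalize(rawPrediction)
    // The predicted class is the index with the largest probability.
    val predicted = probability.indexOf(probability.max)
    println(probability.mkString("[", ", ", "]"))
    println(s"predicted class index = $predicted")
  }
}
```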

    3. Code

    package com.home.spark.ml
    
    import org.apache.spark.SparkConf
    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
    import org.apache.spark.ml.evaluation.{MulticlassClassificationEvaluator, RegressionEvaluator}
    import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
    import org.apache.spark.ml.linalg.{Vector, Vectors}
    import org.apache.spark.ml.regression.DecisionTreeRegressor
    import org.apache.spark.sql.{Dataset, Row, SparkSession}
    
    object Ex_DecisionTree {
      def main(args: Array[String]): Unit = {
        val conf: SparkConf = new SparkConf(true).setMaster("local[2]").setAppName("spark ml")
        val spark = SparkSession.builder().config(conf).getOrCreate()
    
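        // Converting an RDD to a DataFrame or Dataset requires the implicit conversions of a SparkSession instance.
        // Import them here; note that `spark` is the name of the SparkSession object, not a package.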
        import spark.implicits._
    
        val data = spark.sparkContext.textFile("input/iris.data.txt")
          .map(_.split(","))
          .map(a => Iris(
            Vectors.dense(a(0).toDouble, a(1).toDouble, a(2).toDouble, a(3).toDouble),
            a(4))
          ).toDF()
    
        data.createOrReplaceTempView("iris")
        val df = spark.sql("select * from iris")
        // Print the first 10 rows as "label : features"; take(10) avoids collecting the whole dataset to the driver.
        df.take(10).foreach(r => println(s"${r(1)} : ${r(0)}"))
    
    
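        // Index the label column and the feature column.
        // With setMaxCategories(4), features with at most 4 distinct values are treated as categorical.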
        val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(df)
        val featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures")
          .setMaxCategories(4).fit(df)
    
    
        // Decision tree classifier.
        val dtClassifier = new DecisionTreeClassifier().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures")
    
        // Convert the predicted label indices back to the original string labels.
        val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictionLabel").setLabels(labelIndexer.labels)
    
        // Split the dataset into two parts: one for training and one for testing.
        val Array(trainingData, testData): Array[Dataset[Row]] = df.randomSplit(Array(0.7,0.3))
    
        // Build the pipeline.
        val pipeline = new Pipeline().setStages(Array(labelIndexer,featureIndexer,dtClassifier,labelConverter))
    
        // Fit the pipeline to obtain the trained model.
        val modelDecisionTreeClassifier = pipeline.fit(trainingData)
    
        // Predict on the test set.
        val result = modelDecisionTreeClassifier.transform(testData)
    
        result.show(150,false)
    
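        /**
          * Samples are divided into positive-class and negative-class samples.
          * TP: positive samples correctly classified as positive.
          * TN: negative samples correctly classified as negative.
          * FP: negative samples incorrectly classified as positive (actually negative, predicted positive).
          * FN: positive samples incorrectly classified as negative (actually positive, predicted negative).
          *
          * Accuracy (ACC)
          * Total samples = TP + TN + FP + FN
          * ACC = (TP + TN) / (total samples)
          * This metric is mainly suited to datasets with balanced classes.
          */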
        val evaluator = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction")
          .setMetricName("accuracy")
        val accuracy: Double = evaluator.evaluate(result)
    
        println("Accuracy = " + accuracy)
    
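        /**
          * Precision
          * Precision = TP / (TP + FP): of the samples the model predicts as positive, the fraction that are actually positive.
          * "weightedPrecision" averages the per-class precision, weighted by each class's share of the true labels.
          */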
        val evaluator2 = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction")
          .setMetricName("weightedPrecision")
        val weightedPrecision: Double = evaluator2.evaluate(result)
    
        println("weightedPrecision = " + weightedPrecision)
    
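        /**
          * Recall
          * Recall = TP / (TP + FN): of all samples that are actually positive, the fraction the model predicts as positive.
          * "weightedRecall" averages the per-class recall, weighted by each class's share of the true labels.
          */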
        val evaluator3 = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction")
          .setMetricName("weightedRecall")
        val weightedRecall: Double = evaluator3.evaluate(result)
    
        println("weightedRecall = " + weightedRecall)
    
    
        val treeModel = modelDecisionTreeClassifier.stages(2).asInstanceOf[DecisionTreeClassificationModel]
        println("Learned classification tree model:\n" + treeModel.toDebugString)
    
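        // Decision tree regressor. Fitting it to the indexed class labels treats a categorical index as a
        // continuous target, so this part only demonstrates the API.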
        val dtRegressor = new DecisionTreeRegressor().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures")
    
        val pipelineRegressor = new Pipeline()
          .setStages(Array(labelIndexer,featureIndexer,dtRegressor,labelConverter))
    
        val modelRegressor = pipelineRegressor.fit(trainingData)
        val result2 = modelRegressor.transform(testData)
    
        result2.show(150,false)
    
        // Evaluate.
        val regressionEvaluator = new RegressionEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction")
            .setMetricName("rmse")
        val rmse = regressionEvaluator.evaluate(result2)
        println("rmse = " + rmse)
        spark.stop()
      }
    }
    
    case class Iris(features: Vector, label: String)
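
    The accuracy, weighted precision, and weighted recall described in the comments above can be sketched in plain Scala from a confusion matrix. This is an illustration of the formulas rather than the evaluator's own code; the object name `WeightedMetrics` and the confusion matrix are hypothetical. Weighting by each class's share of the true labels is how the weighted metrics are understood here, and under that weighting the weighted recall equals the accuracy.

```scala
// Hypothetical helper object, used only for this illustration.
object WeightedMetrics {
  type Confusion = Array[Array[Long]]

  // confusion(i)(j) = number of samples whose true class is i and predicted class is j.
  def accuracy(confusion: Confusion): Double = {
    val correct = confusion.indices.map(i => confusion(i)(i)).sum
    correct.toDouble / confusion.map(_.sum).sum
  }

  // One-vs-rest precision for class c: TP / (TP + FP).
  def precision(confusion: Confusion, c: Int): Double = {
    val predictedAsC = confusion.map(_(c)).sum
    if (predictedAsC == 0) 0.0 else confusion(c)(c).toDouble / predictedAsC
  }

  // One-vs-rest recall for class c: TP / (TP + FN).
  def recall(confusion: Confusion, c: Int): Double = {
    val actualC = confusion(c).sum
    if (actualC == 0) 0.0 else confusion(c)(c).toDouble / actualC
  }

  // Average a per-class metric, weighting each class by its share of the true labels.
  def weighted(confusion: Confusion, metric: Int => Double): Double = {
    val total = confusion.map(_.sum).sum.toDouble
    confusion.indices.map(c => metric(c) * confusion(c).sum / total).sum
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical 3-class confusion matrix for a 45-sample test set.
    val m: Confusion = Array(
      Array(15L, 0L, 0L),
      Array(0L, 13L, 2L),
      Array(0L, 1L, 14L))
    println("accuracy          = " + accuracy(m))
    println("weightedPrecision = " + weighted(m, c => precision(m, c)))
    println("weightedRecall    = " + weighted(m, c => recall(m, c)))
  }
}
```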
  • Original article: https://www.cnblogs.com/asker009/p/12403407.html