  • Decision Tree Source Code Analysis in Spark

    1. Example

    This example trains a decision tree model with the decision tree classifier API in Spark MLlib; it is written in Python.

    """
    Decision Tree Classification Example.
    """
    from __future__ import print_function
    
    from pyspark import SparkContext
    from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
    from pyspark.mllib.util import MLUtils
    
    if __name__ == "__main__":
    
        sc = SparkContext(appName="PythonDecisionTreeClassificationExample")
    
        # Load and parse the data file into an RDD
        dataPath = "/home/zhb/Desktop/work/DecisionTreeShareProject/app/sample_libsvm_data.txt"
        print(dataPath)
    
        data = MLUtils.loadLibSVMFile(sc,dataPath)
        # Split the data into training and test sets
        (trainingData,testData) = data.randomSplit([0.7,0.3])
        print("train data count: " + str(trainingData.count()))
        print("test data count : " + str(testData.count()))
    
        # Train a decision tree classifier
        # An empty categoricalFeaturesInfo means all features are continuous
        model = DecisionTree.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={},
                                             impurity='gini', maxDepth=5, maxBins=32)
    
        # Predict on the test set
        predictions = model.predict(testData.map(lambda x: x.features))
        # Zip the true labels with the predictions
        labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
        # Compute the fraction of misclassified samples
        testErr = labelsAndPredictions.filter(lambda vp: vp[0] != vp[1]).count() / float(testData.count())
        print('Decision Tree Test Error = %5.3f%%' % (testErr * 100))
        print("Decision Tree Learned classification tree model : ")
        print(model.toDebugString())
    
        # Save and load the trained model
        modelPath = "/home/zhb/Desktop/work/DecisionTreeShareProject/app/myDecisionTreeClassificationModel"
        model.save(sc, modelPath)
        sameModel = DecisionTreeModel.load(sc, modelPath)
    

    2. Decision Tree Source Code Analysis

    The decision tree classifier API is DecisionTree.trainClassifier; let's step into its source code.

    The source file path is spark-1.6/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala.

      @Since("1.1.0")
      def trainClassifier(
          input: RDD[LabeledPoint],
          numClasses: Int,
          categoricalFeaturesInfo: Map[Int, Int],
          impurity: String,
          maxDepth: Int,
          maxBins: Int): DecisionTreeModel = {
        val impurityType = Impurities.fromString(impurity)
        train(input, Classification, impurityType, maxDepth, numClasses, maxBins, Sort,
          categoricalFeaturesInfo)
      }
    

    trainClassifier resolves the impurity name (here "gini") into an Impurity object and then delegates to the train method.

      @Since("1.0.0")
      def train(
          input: RDD[LabeledPoint],
          algo: Algo,
          impurity: Impurity,
          maxDepth: Int,
          numClasses: Int,
          maxBins: Int,
          quantileCalculationStrategy: QuantileStrategy,
          categoricalFeaturesInfo: Map[Int, Int]): DecisionTreeModel = {
        val strategy = new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
          quantileCalculationStrategy, categoricalFeaturesInfo)
        new DecisionTree(strategy).run(input)
      }
    

    The train method first wraps the algorithm type (classification or regression), the impurity measure, the maximum tree depth, the number of classes, the maximum number of bins, and the remaining parameters into a Strategy, then constructs a DecisionTree object and calls its run method.
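
    For reference, here is a minimal Scala sketch of what that wrapping amounts to, assuming a trainingData: RDD[LabeledPoint] analogous to the training set in the Python example (parameter values borrowed from there; an illustration, not a snippet from Spark itself):

    import org.apache.spark.mllib.tree.DecisionTree
    import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
    import org.apache.spark.mllib.tree.impurity.Gini

    // Build the Strategy that train() assembles internally, then run the tree.
    val strategy = new Strategy(
      algo = Algo.Classification,                 // classification task
      impurity = Gini,                            // impurity measure for split gain
      maxDepth = 5,                               // maximum tree depth
      numClasses = 2,                             // binary classification
      maxBins = 32,                               // maximum number of split bins
      categoricalFeaturesInfo = Map[Int, Int]())  // empty: all features continuous
    val model = new DecisionTree(strategy).run(trainingData)

    The quantile calculation strategy is left at its default (Sort), matching what trainClassifier passes in.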

    @Since("1.0.0")
    class DecisionTree private[spark] (private val strategy: Strategy, private val seed: Int)
      extends Serializable with Logging {
    
      /**
       * @param strategy The configuration parameters for the tree algorithm which specify the type
       *                 of decision tree (classification or regression), feature type (continuous,
       *                 categorical), depth of the tree, quantile calculation strategy, etc.
       */
      @Since("1.0.0")
      def this(strategy: Strategy) = this(strategy, seed = 0)
    
      strategy.assertValid()
    
      /**
       * Method to train a decision tree model over an RDD
       *
       * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
       * @return DecisionTreeModel that can be used for prediction.
       */
      @Since("1.2.0")
      def run(input: RDD[LabeledPoint]): DecisionTreeModel = {
        val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = seed)
        val rfModel = rf.run(input)
        rfModel.trees(0)
      }
    }
    

    The run method first creates a RandomForest object, passing in the strategy, a tree count of numTrees = 1, and the feature subset strategy "all"; it then calls RandomForest's run method and finally returns the first decision tree in the resulting random forest model.

    In other words, a decision tree model is trained via the random forest implementation with the number of trees set to 1, and the first (and only) tree of the random forest model is returned as the trained decision tree model, as the sketch below illustrates.
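
    To make the equivalence concrete, here is a sketch using the public API: calling the forest trainer directly with numTrees = 1 and featureSubsetStrategy = "all" should mirror what DecisionTree.run does internally (same hyperparameters as the example above; an illustration, not the library's own code path):

    import org.apache.spark.mllib.tree.RandomForest

    // A forest of one tree that considers every feature at every node.
    val forestModel = RandomForest.trainClassifier(
      trainingData,
      numClasses = 2,
      categoricalFeaturesInfo = Map[Int, Int](),
      numTrees = 1,                    // DecisionTree.run hard-codes numTrees = 1
      featureSubsetStrategy = "all",   // and uses all features, as in a plain tree
      impurity = "gini",
      maxDepth = 5,
      maxBins = 32,
      seed = 0)
    val tree = forestModel.trees(0)    // the single tree DecisionTree.run returns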

    3. Random Forest Source Code Analysis

    The random forest source file path is spark-1.6/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala.

    private class RandomForest (
        private val strategy: Strategy,
        private val numTrees: Int,
        featureSubsetStrategy: String,
        private val seed: Int)
      extends Serializable with Logging {
    
      strategy.assertValid()
      require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given numTrees = $numTrees.")
      require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy)
        || Try(featureSubsetStrategy.toInt).filter(_ > 0).isSuccess
        || Try(featureSubsetStrategy.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess,
        s"RandomForest given invalid featureSubsetStrategy: $featureSubsetStrategy." +
        s" Supported values: ${NewRFParams.supportedFeatureSubsetStrategies.mkString(", ")}," +
        s" (0.0-1.0], [1-n].")
    
      /**
       * Method to train a decision tree model over an RDD
       *
       * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
       * @return RandomForestModel that can be used for prediction.
       */
      def run(input: RDD[LabeledPoint]): RandomForestModel = {
        val trees: Array[NewDTModel] = NewRandomForest.run(input.map(_.asML), strategy, numTrees,
          featureSubsetStrategy, seed.toLong, None)
        new RandomForestModel(strategy.algo, trees.map(_.toOld))
      }
    
    }
    

    At the top of that file, "import org.apache.spark.ml.tree.impl.{RandomForest => NewRandomForest}" imports the RandomForest implementation from the ml package under the alias NewRandomForest.

    RandomForest.run first calls NewRandomForest.run, then converts each resulting tree back to the old mllib API (trees.map(_.toOld)) and passes them to the constructor of a new RandomForestModel.

    The source file for NewRandomForest is spark-1.6/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala.

    Since the amount of code involved is too large to reproduce in full here, the main call chain of the run method is summarized below.

    run method

    --->1. val metadata = DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees,featureSubsetStrategy) # build metadata for the input data

    --->2. val splits = findSplits(retaggedInput, metadata, seed) # find candidate splits for the features

        --->2.1 compute the sampling fraction and sample the input data

        --->2.2 findSplitsBySorting(sampledInput, metadata, continuousFeatures) # find splits over the features of the sampled data

            --->2.2.1 val thresholds = findSplitsForContinuousFeature(samples, metadata, idx) # for continuous features (see the threshold sketch after this list)

            --->2.2.2 val categories = extractMultiClassCategories(splitIndex + 1, featureArity) # for unordered categorical features

            --->2.2.3 Array.empty[Split] # for ordered categorical features, whose splits are constructed directly during training

    --->3. val treeInput = TreePoint.convertToTreeRDD(retaggedInput, splits, metadata) # convert the input data into binned TreePoint data

        --->3.1 input.map { x => TreePoint.labeledPointToTreePoint(x, thresholds, featureArity) # convert each LabeledPoint into a TreePoint

        --->3.2 arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity(featureIndex), thresholds(featureIndex)) # map the value of (labeledPoint, feature) to a discrete bin

    --->4. val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees,withReplacement, seed) # subsample the input data for bagging (see the bagging sketch after this list)

        --->4.1 convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed) # sampling with replacement

        --->4.2 convertToBaggedRDDWithoutSampling(input) # a single subsample at a 100% sampling rate, so no sampling is needed

        --->4.3 convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed) # sampling without replacement

    --->5. val (nodesForGroup, treeToNodeToIndexInfo) = RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage,metadata, rng) # select the nodes of each tree that need to be split

        --->5.1 val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { Some(SamplingUtils.reservoirSampleAndCount(Range(0, metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong())._1)} # if feature subsampling is enabled, choose a feature subset for the node

        --->5.2 val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L # estimate the node's aggregate size to check whether enough memory remains after adding it

    --->6. RandomForest.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup, treeToNodeToIndexInfo, splits, nodeQueue, timer, nodeIdCache) # find the best splits

        --->6.1 val (split: Split, stats: ImpurityStats) = binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex)) # pick the best split for each node (see the Gini-gain sketch after this list)

    --->7. new DecisionTreeClassificationModel(uid, rootNode.toNode, numFeatures, strategy.getNumClasses) # return the decision tree classification model
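
    For step 2.2.1, the idea behind findSplitsForContinuousFeature can be shown with a small threshold sketch (the same weighted-quantile idea, not the Spark code itself; findThresholds and its arguments are made-up names for illustration):

    // Sketch: candidate thresholds for one continuous feature.
    // `samples` are sampled values of the feature; `maxBins` caps the bin count.
    def findThresholds(samples: Array[Double], maxBins: Int): Array[Double] = {
      // Count each distinct value, then order by value.
      val counts = samples.groupBy(identity).mapValues(_.length).toArray.sortBy(_._1)
      val stride = samples.length.toDouble / maxBins   // target weight per bin
      val thresholds = scala.collection.mutable.ArrayBuffer.empty[Double]
      var accumulated = 0.0
      var target = stride
      // Emit a threshold whenever the accumulated weight crosses a stride boundary.
      for (i <- 0 until counts.length - 1) {
        accumulated += counts(i)._2
        if (accumulated >= target) {
          thresholds += counts(i)._1   // split as "feature <= value"
          target += stride
        }
      }
      thresholds.toArray
    }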
    
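    For step 4, BaggedPoint attaches to every instance one count per tree recording how many times the instance occurs in that tree's bootstrap sample. A bagging sketch of the with-replacement case (bootstrapCounts is a made-up name; Spark instead draws Poisson(subsamplingRate) counts per point so the sampling works in one distributed pass):

    import scala.util.Random

    // Sketch: counts(tree)(point) = multiplicity of `point` in `tree`'s sample.
    def bootstrapCounts(numPoints: Int, numTrees: Int, subsamplingRate: Double,
                        seed: Long): Array[Array[Int]] = {
      val rng = new Random(seed)
      val draws = (numPoints * subsamplingRate).toInt   // draws per subsample
      val counts = Array.fill(numTrees, numPoints)(0)
      for (tree <- 0 until numTrees; _ <- 0 until draws) {
        counts(tree)(rng.nextInt(numPoints)) += 1       // sample with replacement
      }
      counts
    }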

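    For step 6.1, binsToBestSplit scores every candidate split by impurity gain and keeps the best one. A Gini-gain sketch for binary classification (gini and bestSplit are illustrative names; the real code works on ImpurityStats aggregated over bins rather than raw points):

    // Gini impurity of a node with `pos` positives out of `total` points:
    // 1 - p^2 - (1 - p)^2 = 2p(1 - p).
    def gini(pos: Double, total: Double): Double =
      if (total == 0) 0.0 else { val p = pos / total; 2.0 * p * (1.0 - p) }

    // Pick the threshold with the highest Gini gain for one feature.
    // `data` holds (featureValue, label) pairs with labels 0.0 or 1.0.
    def bestSplit(data: Array[(Double, Double)],
                  thresholds: Array[Double]): (Double, Double) = {
      val total = data.length.toDouble
      val totalPos = data.count(_._2 == 1.0).toDouble
      val parentImpurity = gini(totalPos, total)
      val scored = thresholds.map { t =>
        val left = data.filter(_._1 <= t)
        val leftTotal = left.length.toDouble
        val leftPos = left.count(_._2 == 1.0).toDouble
        val rightTotal = total - leftTotal
        val rightPos = totalPos - leftPos
        // gain = parent impurity minus the weighted child impurities
        val gain = parentImpurity -
          (leftTotal / total) * gini(leftPos, leftTotal) -
          (rightTotal / total) * gini(rightPos, rightTotal)
        (t, gain)
      }
      scored.maxBy(_._2)   // (bestThreshold, bestGain)
    }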