zoukankan      html  css  js  c++  java
  • Spark ML源码分析之四 树

            之前我们讲过,在Spark ML中所有的机器学习模型都是以参数作为划分的,树相关的参数定义在treeParams.scala这个文件中,这里构建一个关于树的体系结构。首先,以DecisionTreeParams作为起始,这里存储了跟树相关的最基础的参数,注意它扩展自PredictorParams。接下来为了区分分类器和回归器,提出了TreeClassifierParams和TreeRegressorParams,两者都直接扩展自Params,分别定义了树相关的分类器和回归器所需要的特殊参数。有了这三个trait,我们可以组合成DecisionTreeClassifierParams和DecisionTreeRegressorParams。然后,为了产生Ensemble的模型,提出了TreeEnsembleParams,专门针对树集合的模型,在此基础上,分别为随机森林和GBT定义了RandomForestParams和GBTParams,可以通过混入分类器和回归器,在这俩的基础上生成随机森林分类器、随机森林回归器、GBT分类器、GBT回归器。
            讲完了参数再讲讲模型,树相关的模型只有两种,DecisionTreeModel和TreeEnsembleModel,比较简单就不深入介绍了。
            下面说一下“树”结构的具体表示,这里涉及到Node.scala文件。树由节点组成,这个文件中定义了四种节点,分别是,Node,这是最基础的节点,是其它三种节点的超类;LeafNode,这是叶子节点,InternalNode,这是中间节点。最后一种是LearningNode,这种节点是专门用来在树学习过程中使用的,其余三类节点都是不可变对象,唯有这个LearningNode是可变对象,可以在学习的过程中被不断迭代,在学习结束之后转换成LeafNode和InternalNode。
            接下来再说说数据样本在树学习中的表示形式。由于树学习的特殊性,在正式用数据进行学习之前,需要对数据进行转换和丰富,Spark ML为了方便树学习,提出了很多数据结构。用于树学习的数据样本,通常可以包含离散和连续两种特征,本质上树学习仅支持离散特征,因此在正式学习前都需要对特征进行离散化。离散化后的样本数据可以表示成一个数组,数组中的元素就是该样本在不同特征的取值。因此,一个原始的LabeledPoint就可以被表示为一个TreePoint,它包含两个数据,Label和binFeature(就是数据特征数组)。这样对于单棵树的学习足够了,但是对于随机森林一类的多树的学习还不行,对于随机森林来说,同样的一个数据样本点可能在森林中的任一一颗树上出现(采样),因此在树学习时,我们需要记录这个数据节点在某棵树上有没有出现过,我们用一个数组来表示一个节点在各树上出现的次数,其中数组长度代表森林中树的数量,数组元素代表数据节点在这棵树上出现的次数,因此就形成了BaggedPoint结构。最后,我们知道随机森林学习的过程,实际上就是不断扩展节点的过程,当前这个数据在每颗树上的节点位置,也需要记录,因此提出NodeInCache数据结构,这同样也是个数组,数组长度是森林中树的个数,数组元素代表当前数据在这棵树上的节点编号。某棵树上的数据刚开始都在根节点,随着学习的进行,数据会逐渐从根节点向下走,直到某个节点因为不能再分解,形成叶节点为止。因此这个数组的内容也是在不断迭代的。关于树的结构,这里再多说一句,如果某个节点的编号是id,那么它的左子节点的编号就是id<<1,右子节点编号就是id<<1+1,因此在一棵树中,某个节点的位置和编号是一一对应的。
            做了这么多铺垫,最后终于说到树学习的过程了,具体算法在RandomForest.scala文件中,这个文件是整个tree文件夹的集大成者,其中包含了随机森林模型学习的核心代码。学习的过程可以分为以下几个步骤,第一,创建metadata,在学习之前,我们要先对数据集有一个大致的了解,比如有多少数据、每个数据多少维度等等,因此需要先从原始数据中总结出metadata。第二,寻找切分点,由于数据特征不可能都是离散型,或者即便是离散型,但因为类别太多,需要进一步处理,因此需要在这里对每个特征寻找切分点,然后把每一个特征值放入特定的切分区间内,形成原始的binFeature向量。第三,转换为BaggedPoint,关于BaggedPoint的含义,之前说的很清楚了,主要是为了方便学习和记录,加入了额外的数据结构。第四,初始化NodeInCache,这个也是为了加速训练过程。第五,初始化NodeStack,我们知道在学习随机森林模型时,需要不断扩展已有的树,把节点拆分为左右子节点,那么已有的树由那么多节点,先拆分哪些节点呢?Spark ML的做法是,把待拆分的节点放入一个栈中,每次根据现有内存的容量,从中拿出若干个点用来拆分,然后把拆分好的点再作为中间节点放入栈中。重复这个过程,直到栈中没有节点为止。因此这里的第五步就是初始化这个节点栈。
            第六,就正式进入随机森林学习的大循环了。这个大循环包含三个步骤,1,master端,在NodeStack中选择节点和特征,在随机森林学习中,每个节点的待选特征集是随机的,不一定就是特征全集,所以这里除了要选节点之外,还要选特征集;2,worker端,计算当前worker上各数据针对选出节点的充分统计量,然后把各worker上的充分统计量在某一个worker上汇总,由这个worker通过算法确定切分点,然后把选好的切分标准返回给master;3,master端,收集各worker返回的切分结果,对模型进行迭代,然后在下一轮循环前,把当前的模型push给各个worker。循环的三个步骤中,计算量最大的就是第二个步骤,既要计算充分统计量,又要把各worker上的统计量汇总,这里的计算和通信压力都很大,是随机森林的性能瓶颈所在。因此在使用随机森林模型时,在数据量允许的情况下,尽量不要把executor数量设置的太高,否则通信的成本会很大。
            为什么没讲决策树直接讲了随机森林呢?因为角色数就是只有一棵树的随机森林嘛,简化一下就得到了。
            以上就是Spark ML中树模型的基础内容,讲的还比较烦,有时间的话再专门出一篇博客,详细讲解随机森林学习中的优化方法。还请大家不吝赐教。

  • 相关阅读:
    范德蒙矩阵相关
    bat运行exe程序
    github 用token远程连接(三)
    为什么将样本方差除以N1?
    Git commit格式 详解(二)
    C++中this与*this的区别
    函数末尾加入const的作用
    git 使用小补充(四)
    人工智能 机器学习
    机器学习分类
  • 原文地址:https://www.cnblogs.com/jicanghai/p/8686392.html
Copyright © 2011-2022 走看看