zoukankan      html  css  js  c++  java
  • 《机器学习实战》第三章:决策树(2)树的构造

    很好。这一篇里面我们来写一些代码。

    决策树可以通过递归的方式来构造。在真正建树之前,我们先来写一些子模块的代码。

    ------------------------------------------------------------------------------------------------

    计算给定数据集的熵

    我们先拿个例子来做数据集吧。就是下面这个海洋生物数据:


    两个特征:(1)不复出水面是否可以生存(英语:no surfacing);(2)是否有脚蹼(英语:flippers)

    一个标签:是否属于鱼类。有2中分类:YES / NO。

    共5条数据。

    导入这个数据集:

    [python] view plain copy
    1. def createDataSet():  
    2.     dataSet = [[1, 1, 'YES'],  
    3.                [1, 1, 'YES'],  
    4.                [1, 0, 'NO'],  
    5.                [0, 1, 'NO'],  
    6.                [0, 1, 'NO']]  
    7.     featNames = ['no surfacing','flippers']  
    8.     return dataSet, featNames  

    其中,featNames中的两项分别是2个特征的名称。

    以下是计算给定数据集dataSet的熵:

    [python] view plain copy
    1. from math import log  
    2.   
    3. def calcShannonEnt(dataSet):  
    4.     numEntries = len(dataSet)  
    5.     labelCounts = {}  
    6.     for featVec in dataSet:  
    7.         currentLabel = featVec[-1]  
    8.         if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0  
    9.         labelCounts[currentLabel] += 1  
    10.     shannonEnt = 0.0  
    11.     for key in labelCounts:  
    12.         prob = float(labelCounts[key])/numEntries  
    13.         shannonEnt -= prob * log(prob, 2)  
    14.     return shannonEnt  

    labelCounts是个字典,对每一种分类,统计出现的次数。比如上面那个例子,就是 {'YES':2 , 'NO':3}

    然后就是按照熵的计算公式来算了。要注意的是python的log函数,它的底数是放在第2个参数位置。

    ------------------------------------------------------------------------------------------------

    划分数据集

    这一小节来看看选取划分特征划分数据集的代码。

    前者待会讲,先看后者的:

     

    [python] view plain copy
    1. def splitDataSet(dataSet, axis, value):  
    2.     retDataSet = []  
    3.     for featVec in dataSet:  
    4.         if featVec[axis] == value:  
    5.             reducedFeatVec = featVec[:axis]  # chop out axis used for splitting  
    6.             reducedFeatVec.extend(featVec[axis + 1:])  
    7.             retDataSet.append(reducedFeatVec)  
    8.     return retDataSet  

    splitDataSet函数用来划分数据集,三个参数:

    dataSet是待划分的数据集,按照下标为axis的特征来划分。

    划分的结果是dataSet数据集中,下标为axis的特征的值为value的数据组成的子数据集。

    要注意的是,获得子数据集,数据是不含下标为axis的特征的,因为已经选过这个特征了,所以要把它剔除掉。

    测试一下:

    [python] view plain copy
    1. dataSet, feats = createDataSet()  
    2. print splitDataSet(dataSet, 0, 1) #按特征0划分,特征值为1  
    3. print splitDataSet(dataSet, 0, 0) #按特征0划分,特征值为0  



    接下来就是选取最佳特征了。

     

    [python] view plain copy
    1. def chooseBestFeatureToSplit(dataSet):  
    2.     numFeatures = len(dataSet[0]) - 1  #每条数据的特征数量  
    3.     baseEntropy = calcShannonEnt(dataSet)  #划分前的熵  
    4.     bestInfoGain = 0.0;  #记录最高信息增益  
    5.     bestFeature = -1  #记录最佳特征  
    6.     for i in range(numFeatures):  #遍历每个特征  
    7.         featList = [data[i] for data in dataSet]  #把所有数据的该特征值抽出来放到一个list里面  
    8.         uniqueVals = set(featList)  #利用set找出该特征所有不同的值  
    9.         newEntropy = 0.0  
    10.         for value in uniqueVals:  #按这些不同的特征值,分别划分成子数据集  
    11.             subDataSet = splitDataSet(dataSet, i, value)  
    12.             prob = len(subDataSet) / float(len(dataSet))  #子数据集的权重  
    13.             newEntropy += prob * calcShannonEnt(subDataSet)  #子数据集的权重*熵  
    14.         infoGain = baseEntropy - newEntropy  #计算信息增益  
    15.         if infoGain > bestInfoGain:  #更新最高信息增益和最佳特征  
    16.             bestInfoGain = infoGain  
    17.             bestFeature = i  
    18.     return bestFeature  #返回最佳特征的下标  

    python的set函数可以把一个list里出现过的不同的值摘取出来,就是去重的作用。

    仍然测试一下:

    [python] view plain copy
    1. dataSet, feats = createDataSet()  
    2. print chooseBestFeatureToSplit(dataSet)  

    结果是0。也就是说最初那5条数据,最佳特征是特征0,也就是“不浮在水面是否可以生存”。得按这个特征划分。

    ------------------------------------------------------------------------------------------------

    递归构建决策树

    构造决策树的大致流程:

    -- 得到原始数据集,选出最佳特征,按照它,划分成多个子数据集,生成多个分支

    -- 对于每个子数据集,又选出最佳特性,互粉多个子数据集,生成多个分支

    所以,我们可以采用递归的原则处理数据集。

    递归的结束条件是:程序遍历完所有可用于划分数据集的属性,或者每个分支下的所有数据都具有相同的分类(即标签)


    这时会出现一个问题:我们用的是ID3算法,每在一个节点进行划分,都会“消耗”掉一个特征。可以这样想,决策树每往下构造一层,能用于划分数据集的特征就少一个。那么,如果到某一个节点,没有特征可用了,而此时这堆数据的标签并不是同一个,怎么办?很简单,投票咯,少数服从多数。

    [python] view plain copy
    1. import operator  
    2. def majorityCnt(classList):  
    3.     classCount={}  
    4.     for vote in classList:  
    5.         if vote not in classCount.keys(): classCount[vote] = 0  
    6.         classCount[vote] += 1  
    7.     sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)  
    8.     return sortedClassCount[0][0]  

    这个majorityCnt函数的作用,就是找出classList这堆标签里,出现次数最多的哪个标签。

    classCount是个字典,记录每个标签出现的次数。

    sorted函数classCount按照键值对的值进行降序排序,返回一个tuple的list。sortedClassCount[0][0]就是出现次数最多的那个标签了。


    现在Boss就可以登场了:创建树!

    [python] view plain copy
    1. def createTree(dataSet, featNames):  
    2.     classList = [data[-1] for data in dataSet]  #当前数据集的所有标签  
    3.     if classList.count(classList[0]) == len(classList):  #如果这堆标签全都一样的话,返回这个标签。  
    4.         return classList[0]  
    5.     if len(dataSet[0]) == 1:  #如果当前数据集一个特征都不剩了,那就不用再划分下去了  
    6.         return majorityCnt(classList)  #直接投票,返回出现次数最多的标签  
    7.     bestFeat = chooseBestFeatureToSplit(dataSet)  #选出用于划分的最佳属性  
    8.     bestFeatName = featNames[bestFeat]  #最佳属性的属性名称  
    9.     myTree = {bestFeatName:{}}  #字典:记录最佳属性对应的标签种类、出现次数情况  
    10.     del(featNames[bestFeat])  #在属性名称列表中剔除最佳属性  
    11.     featValues = [data[bestFeat] for data in dataSet]  #当前数据集中最佳属性的所有属性值  
    12.     uniqueVals = set(featValues) #最佳属性的不同属性值  
    13.     for value in uniqueVals:  
    14.         subfeatNames = featNames[:]  #去除最佳属性后的属性名称列表  
    15.         # 构建最佳属性的值为value的子树  
    16.         myTree[bestFeatName][value] = createTree(splitDataSet(dataSet, bestFeat, value),subfeatNames)  
    17.     return myTree  

    很神奇的地方在于三个return中,前两个返回类型是标签(integer),而第三个返回类型是一棵树(dict)。

    这是因为构建一个叶节点时,我们需要知道它这堆数据对应的是哪个标签;而构建一个内部节点时,我们需要知道划分之后它有哪些子节点。

    另外...哎,你说C++啥的,你怎样让一个函数里不同分支返回不同的数据类型嘛?联合体吗?python大法好!


    现在我们来测试一下,看构造出来的是个什么玩意儿:

    [python] view plain copy
    1. dataSet, feats = createDataSet()  
    2. theTree = createTree(dataSet, feats)  
    3. print theTree  

    所以这是个什么东西?

    画出来就明了了:



    好了,代码就是这些了。用的时候只要在最开始的时候,按格式把自己的数据集导入程序就可以了。

  • 相关阅读:
    Windows Phone 8 开发环境搭建
    常用正则表达式大全分享
    ios 使用NSRegularExpression解析正则表达式
    大整数类BIGN的设计与实现 C++高精度模板
    CODEVS_1227 方格取数2 网络流 最小费用流 拆点
    CODEVS_1034 家园 网络流 最大流
    CODEVS_1033 蚯蚓的游戏问题 网络流 最小费用流 拆点
    HDU_4770 Lights Against Dudely 状压+剪枝
    CODEVS_2144 砝码称重 2 折半搜索+二分查找+哈希
    CODEVS_1074 食物链
  • 原文地址:https://www.cnblogs.com/wyuzl/p/7699897.html
Copyright © 2011-2022 走看看