zoukankan      html  css  js  c++  java
  • 《机器学习实战》笔记——决策树(ID3)

    闲来无事最近复习了一下ID3决策树算法,并凭着理解用pandas实现了一遍。对pandas更熟悉的朋友可供参考(链接如下)。相比本篇博文,更简明清晰,更适合复习用。

    https://github.com/DianeSoHungry/ShallowMachineLearningCodeItOut/blob/master/ID3.ipynb

     

     

     

    现在要介绍的是ID3决策树算法,只适用于标称型数据,不适用于数值型数据。

     

    决策树学习算法最大的优点是,他可以自学习,在学习过程中,不需要使用者了解过多的背景知识、领域知识,只需要对训练实例进行较好的标注就可以自学习了。

     

    建立决策树的关键在于当前状态下选择哪一个属性作为分类依据,根据不同的目标函数,有三种主要的算法

    ID3(Iterative Dichotomiser)

    C4.5

    CART(Classification And Regression Tree)

    问题描述:

    下面是一个小型的数据集,5条记录,2个特征(属性),有标签。

    根据这个数据集,我们可以建立如下决策树(用matplotlib的注释功能画的)。

    观察决策树,决策节点为特征,其分支为决策节点的各个不同取值,叶节点为预测值。

    建树结束也就是建立好了一个决策树分类器,有了分类器,就可以根据这个分类器对其他的鱼进行预测了。预测准确性今天暂且不讨论。

    那么如何建立这样的决策树呢?

    第一步:建立决策树。

    1.1 利用信息增益寻找当前最佳分类特征

    想象现在你是一个判断结点,你从头顶的分支上获得了一个数据集,表中包含标签和若干属性。你现在要根据某个属性来对你接收到的数据集进行分组。到底用哪个属性来作为划分依据呢?

    我们用信息增益来选择某个节点上用哪个特征来进行分类。

    什么是信息?

    如果待分类的事物可能划分在多个分类中,则每个分类xi的信息定义为:

    (这里log前面应该有个负号。)

    什么是香农熵?

    香农熵是所有类别所有可能类别信息的期望值,即:

    什么是信息增益?

    信息增益=原香农熵-新香农熵

    注意:新香农熵为按照某特征划分之后,每个分支数据集的香农熵之和。

      

    可以这样想:香农熵相当于数据类别(标签)的混乱程度,信息增益可以衡量划分数据集前后数据(标签)向有序性发展的程度。因此,回到怎样利用信息增益寻找当前最佳分类特征的话题,假如你是一个判断节点,你拿来一个数据集,数据集里面有若干个特征,你需要从中选取一个特征,使得信息增益最大(注意:将数据集中在该特征上取值相同的记录划分到同一个分支,得到若干个分支数据集,每个分支数据集都有自己的香农熵,各个分支数据集的香农熵的期望才是新香农熵)。要找到这个特征只需要将数据集中的每个特征遍历一次,求信息增益,取获得最大信息增益的那个特征。

    代码如下(其中,calcShannonEnt(dataSet)函数用来计算数据集dataSet的香农熵,splitDataSet(dataSet, axis, value)函数将数据集dataSet的第axis列中特征值为value的记录挑出来,组成分支数据集返回给函数。这两个函数后面会给出函数定义。):

     1 # 3-3 选择最好的'数据集划分方式'(特征)
     2 # 一个一个地试每个特征,如果某个按照某个特征分类得到的信息增益(原香农熵-新香农熵)最大,
     3 # 则选这个特征作为最佳数据集划分方式
     4 def chooseBestFeatureToSplit(dataSet):
     5     numFeatures = len(dataSet[0]) - 1
     6     baseEntropy = calcShannonEnt(dataSet)
     7     bestInfoGain = 0.0
     8     bestFeature = -1
     9     for i in range(numFeatures):
    10         featList = [example[i] for example in dataSet]
    11         uniqueVals = set(featList)
    12         newEntropy = 0.0
    13         for value in uniqueVals:
    14             subDataSet = splitDataSet(dataSet, i, value)
    15             prob = len(subDataSet) / float(len(dataSet))
    16             newEntropy += prob * calcShannonEnt(subDataSet)
    17         infoGain = baseEntropy - newEntropy
    18         if (infoGain > bestInfoGain):
    19             bestInfoGain = infoGain
    20             bestFeature = i
    21     return bestFeature

    calcShannonEnt(dataSet)函数代码:

     1 def calcShannonEnt(dataSet):
     2     numEntries = len(dataSet)    # 总记录数
     3     labelCounts = {}    # dataSet中所有出现过的标签值为键,相应标签值出现过的次数作为值
     4     for featVec in dataSet:
     5         currentLabel = featVec[-1]
     6         labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
     7     shannonEnt = 0.0
     8     for key in labelCounts:
     9         prob = -float(labelCounts[key])/numEntries
    10         shannonEnt += prob * log(prob, 2)
    11     return shannonEnt

    splitDataSet(dataSet, axis, value)函数代码:

     1 # 3-2 按照给定特征划分数据集(在某个特征axis上,值等于value的所有记录
     2 # 组成新的数据集retDataSet,新数据集不需要axis这个特征,注意value是特征值,axis指的是特征(所在的列下标))
     3 def splitDataSet(dataSet, axis, value):
     4     retDataSet = []
     5     for featVec in dataSet:
     6         if featVec[axis] == value:
     7             reducedFeatVec = featVec[:axis]
     8             reducedFeatVec.extend(featVec[axis+1:])
     9             retDataSet.append(reducedFeatVec)
    10     return retDataSet

    1.2 建树

    建树是一个递归的过程。

    递归结束的标志(判断某节点是叶节点的标志):

    情况1. 分到该节点的数据集中,所有记录的标签列取值都一样。

    情况2. 分到该节点的数据集中,只剩下标签列。

    a. 经判断,若是叶节点,则:

    对应情况1,返回数据集中第一条记录的标签值(反正所有标签值都一样)。

    对应情况2,返回数据集中所有标签值中,出现次数最多的那个标签值(代码中,定义一个函数majorityCnt(classList)来实现)

    b. 经判断,若不是叶节点,则:

    step1. 建立一个字典,字典的键为该数据集上选出的最佳特征(划分依据)。

    step2. 将具有相同特征值的记录组成新的数据集(利用splitDataSet(dataSet, axis, value)函数实现,注意期间抛弃了当前用于划分数据的特征列),对新的数据集们进行递归建树。

    建树代码:

     1 # 3-4 创建树的函数代码
     2 # 如果非叶子结点,则以当前数据集建树,并返回该树。该树的根节点是一个字典,键为划分当前数据集的最佳特征,值为按照键值划分后各个数据集构造的树
     3 # 叶子节点有两种:1.只剩没有特征时,叶子节点的返回值为所有记录中,出现次数最多的那个标签值 2.该叶子节点中,所有记录的标签相同。
     4 
     5 def createTree(dataSet, labels): #label向量的维度为特征数,不是记录数,是不同列下标对应的特征
     6     classList = [example[-1] for example in dataSet]
     7     if classList.count(classList[0]) == len(classList):
     8         return classList[0]
     9     if len(dataSet[0]) == 1:
    10         return majorityCnt(classList)
    11     bestFeat = chooseBestFeatureToSplit(dataSet)
    12     bestFeatLabel = labels[bestFeat]
    13     myTree = {bestFeatLabel: {}}
    14     del(labels[bestFeat])
    15     featValues = [example[bestFeat] for example in dataSet]
    16     uniqueVals = set(featValues)
    17     for value in uniqueVals:  #递归建子树,若值为字典,则非叶节点,若为字符串,则为叶节点
    18         myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), labels)
    19     return myTree

    用上面给出的数据来建立一颗决策树做示范:

    在同一个程序中输入如下代码并运行:

     1 def createDataSet():
     2     dataSet = [[1, 1, 'yes'],
     3                [1, 1, 'yes'],
     4                [1, 0, 'no'],
     5                [0, 1, 'no'],
     6                [0, 1, 'no']]
     7     labels = ['no surfacing', 'flippers']
     8     return dataSet, labels
     9 
    10 myDat, labels = createDataSet()
    11 myTree = createTree(myDat, labels)
    12 print myTree

    运行结果为:

    {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
    

     若利用后面画决策树的代码可以画出这颗决策树:

    案例:

    我们通过建立决策树来预测患者需要佩戴哪种隐形眼镜(soft(软材质)、hard(硬材质)、no lenses(不适合硬性眼睛)),数据集包含下面几个特征:age(年龄), prescript(近视还是远视), astigmatic(散光), tearRate(眼泪清除率)

    建树的结果为:

    {'tearRate': {'reduced': 'no lenses', 'normal': {'astigmatic': {'yes': {'prescript': {'hyper': {'age': {'pre': 'no lenses', 'presbyopic': 'no lenses', 'young': 'hard'}}, 'myope': 'hard'}}, 'no': {'age': {'pre': 'soft', 'presbyopic': {'prescript': {'hyper': 'soft', 'myope': 'no lenses'}}, 'young': 'soft'}}}}}}
    

    画出来是这个样子:

    画决策树的代码(不讲)

    涉及matplotlib.pyplot模块中的annotation的用法,点击链接进入官网学习这块内容的prerequisite。

     1 # _*_coding:utf-8_*_
     2 
     3 # 3-7 plotTree函数
     4 import matplotlib.pyplot as plt
     5 
     6 # 定义节点和箭头格式的常量
     7 decisionNode = dict(boxstyle="sawtooth", fc="0.8")
     8 leafNode = dict(boxstyle="round4", fc="0.8")
     9 arrow_args = dict(arrowstyle="<-")
    10 
    11 
    12 def plotMidTest(cntrPt, parentPt,txtString):
    13     xMid = (parentPt[0] + cntrPt[0])/2.0
    14     yMid = (parentPt[1] + cntrPt[1])/2.0
    15     createPlot.ax1.text(xMid, yMid, txtString)
    16 
    17 # 绘制自身
    18 # 若当前子节点不是叶子节点,递归
    19 # 若当子节点为叶子节点,绘制该节点
    20 def plotTree(myTree, parentPt, nodeTxt):
    21     numLeafs = getNumLeafs(myTree)
    22     # depth = getTreeDepth(myTree)
    23     firstStr = myTree.keys()[0]
    24     cntrPt = (plotTree.xoff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yoff)
    25     plotMidTest(cntrPt, parentPt, nodeTxt)
    26     plotNode(firstStr, cntrPt, parentPt, decisionNode)
    27     secondDict = myTree[firstStr]
    28     plotTree.yoff = plotTree.yoff - 1.0/plotTree.totalD
    29     for key in secondDict.keys():
    30         if type(secondDict[key]).__name__=='dict':
    31             plotTree(secondDict[key], cntrPt, str(key))
    32         else:
    33             plotTree.xoff = plotTree.xoff + 1.0/plotTree.totalW
    34             plotNode(secondDict[key], (plotTree.xoff, plotTree.yoff), cntrPt, leafNode)
    35             plotMidTest((plotTree.xoff, plotTree.yoff), cntrPt, str(key))
    36     plotTree.yoff = plotTree.yoff + 1.0/plotTree.totalD
    37 
    38 
    39 # figure points
    40 # 画结点的模板
    41 def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    42     createPlot.ax1.annotate(nodeTxt,  # 注释的文字,(一个字符串)
    43                             xy=parentPt,  # 被注释的地方(一个坐标)
    44                             xycoords='axes fraction',  # xy所用的坐标系
    45                             xytext=centerPt,  # 插入文本的地方(一个坐标)
    46                             textcoords='axes fraction', # xytext所用的坐标系
    47                             va="center",
    48                             ha="center",
    49                             bbox=nodeType,  # 注释文字用的框的格式
    50                             arrowprops=arrow_args)  # 箭头属性
    51 
    52 
    53 def createPlot(inTree):
    54     fig = plt.figure(1, facecolor='white')
    55     fig.clf()
    56     axprops = dict(xticks=[], yticks=[])
    57     createPlot.ax1 = plt.subplot(111,frameon=False, **axprops)
    58     plotTree.totalW = float(getNumLeafs(inTree))
    59     plotTree.totalD = float(getTreeDepth(inTree))
    60     plotTree.xoff = -0.5/plotTree.totalW
    61     plotTree.yoff = 1.0
    62 
    63     plotTree(inTree, (0.5, 1.0),'') #树的引用作为父节点,但不画出来,所以用''
    64     plt.show()
    65 
    66 def getNumLeafs(myTree):
    67     numLeafs = 0
    68     firstStr = myTree.keys()[0]
    69     secondDict = myTree[firstStr]
    70     for key in secondDict.keys():
    71         if type(secondDict[key]).__name__ =='dict':
    72             numLeafs += getNumLeafs(secondDict[key])
    73         else:
    74             numLeafs += 1
    75     return numLeafs
    76 
    77 # 子树中树高最大的那一颗的高度+1作为当前数的高度
    78 def getTreeDepth(myTree):
    79     maxDepth = 0    #用来记录最高子树的高度+1
    80     firstStr = myTree.keys()[0]
    81     secondDict = myTree[firstStr]
    82     for key in secondDict.keys():
    83         if type(secondDict[key]).__name__ == 'dict':
    84             thisDepth = 1 + getTreeDepth(secondDict[key])
    85         else:
    86             thisDepth = 1
    87         if(thisDepth > maxDepth):
    88             maxDepth = thisDepth
    89     return maxDepth
    90 
    91 # 方便测试用的人造测试树
    92 def retrieveTree(i):
    93     listofTrees = [{'no surfacing':{0:'no',1:{'flippers':{0:'no',1:'yes'}}}},
    94                    {'no surfacing':{0:'no',1:{'flippers':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}}
    95                    ]
    96     return listofTrees[i]

     

     

  • 相关阅读:
    docker 介绍,安装,镜像操作, docker换源
    go语言5 接口, 并发与并行, go协程, 信道, 缓冲信道, 异常处理, python进程线程
    [编织消息框架]目录
    2017总结
    赚钱方法[信息红利]
    面单 全单 单板 批发吉他民谣 知乎 百度知道 百度贴吧 吉他批发
    看第三部杀破狼感想
    海豚极货店 淘宝店开张啦
    我上头条了
    尤克里里 ukulele 单板 非kaka tom uma
  • 原文地址:https://www.cnblogs.com/DianeSoHungry/p/7059104.html
Copyright © 2011-2022 走看看