zoukankan      html  css  js  c++  java
  • 《机器学习实战》笔记——决策树(ID3)

    闲来无事最近复习了一下ID3决策树算法,并凭着理解用pandas实现了一遍。对pandas更熟悉的朋友可供参考(链接如下)。相比本篇博文,更简明清晰,更适合复习用。

    https://github.com/DianeSoHungry/ShallowMachineLearningCodeItOut/blob/master/ID3.ipynb

     

     

     

    现在要介绍的是ID3决策树算法,只适用于标称型数据,不适用于数值型数据。

     

    决策树学习算法最大的优点是,他可以自学习,在学习过程中,不需要使用者了解过多的背景知识、领域知识,只需要对训练实例进行较好的标注就可以自学习了。

     

    建立决策树的关键在于当前状态下选择哪一个属性作为分类依据,根据不同的目标函数,有三种主要的算法

    ID3(Iterative Dichotomiser)

    C4.5

    CART(Classification And Regression Tree)

    问题描述:

    下面是一个小型的数据集,5条记录,2个特征(属性),有标签。

    根据这个数据集,我们可以建立如下决策树(用matplotlib的注释功能画的)。

    观察决策树,决策节点为特征,其分支为决策节点的各个不同取值,叶节点为预测值。

    建树结束也就是建立好了一个决策树分类器,有了分类器,就可以根据这个分类器对其他的鱼进行预测了。预测准确性今天暂且不讨论。

    那么如何建立这样的决策树呢?

    第一步:建立决策树。

    1.1 利用信息增益寻找当前最佳分类特征

    想象现在你是一个判断结点,你从头顶的分支上获得了一个数据集,表中包含标签和若干属性。你现在要根据某个属性来对你接收到的数据集进行分组。到底用哪个属性来作为划分依据呢?

    我们用信息增益来选择某个节点上用哪个特征来进行分类。

    什么是信息?

    如果待分类的事物可能划分在多个分类中,则每个分类xi的信息定义为:

    (这里log前面应该有个负号。)

    什么是香农熵?

    香农熵是所有类别所有可能类别信息的期望值,即:

    什么是信息增益?

    信息增益=原香农熵-新香农熵

    注意:新香农熵为按照某特征划分之后,每个分支数据集的香农熵之和。

      

    可以这样想:香农熵相当于数据类别(标签)的混乱程度,信息增益可以衡量划分数据集前后数据(标签)向有序性发展的程度。因此,回到怎样利用信息增益寻找当前最佳分类特征的话题,假如你是一个判断节点,你拿来一个数据集,数据集里面有若干个特征,你需要从中选取一个特征,使得信息增益最大(注意:将数据集中在该特征上取值相同的记录划分到同一个分支,得到若干个分支数据集,每个分支数据集都有自己的香农熵,各个分支数据集的香农熵的期望才是新香农熵)。要找到这个特征只需要将数据集中的每个特征遍历一次,求信息增益,取获得最大信息增益的那个特征。

    代码如下(其中,calcShannonEnt(dataSet)函数用来计算数据集dataSet的香农熵,splitDataSet(dataSet, axis, value)函数将数据集dataSet的第axis列中特征值为value的记录挑出来,组成分支数据集返回给函数。这两个函数后面会给出函数定义。):

     1 # 3-3 选择最好的'数据集划分方式'(特征)
     2 # 一个一个地试每个特征,如果某个按照某个特征分类得到的信息增益(原香农熵-新香农熵)最大,
     3 # 则选这个特征作为最佳数据集划分方式
     4 def chooseBestFeatureToSplit(dataSet):
     5     numFeatures = len(dataSet[0]) - 1
     6     baseEntropy = calcShannonEnt(dataSet)
     7     bestInfoGain = 0.0
     8     bestFeature = -1
     9     for i in range(numFeatures):
    10         featList = [example[i] for example in dataSet]
    11         uniqueVals = set(featList)
    12         newEntropy = 0.0
    13         for value in uniqueVals:
    14             subDataSet = splitDataSet(dataSet, i, value)
    15             prob = len(subDataSet) / float(len(dataSet))
    16             newEntropy += prob * calcShannonEnt(subDataSet)
    17         infoGain = baseEntropy - newEntropy
    18         if (infoGain > bestInfoGain):
    19             bestInfoGain = infoGain
    20             bestFeature = i
    21     return bestFeature

    calcShannonEnt(dataSet)函数代码:

     1 def calcShannonEnt(dataSet):
     2     numEntries = len(dataSet)    # 总记录数
     3     labelCounts = {}    # dataSet中所有出现过的标签值为键,相应标签值出现过的次数作为值
     4     for featVec in dataSet:
     5         currentLabel = featVec[-1]
     6         labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
     7     shannonEnt = 0.0
     8     for key in labelCounts:
     9         prob = -float(labelCounts[key])/numEntries
    10         shannonEnt += prob * log(prob, 2)
    11     return shannonEnt

    splitDataSet(dataSet, axis, value)函数代码:

     1 # 3-2 按照给定特征划分数据集(在某个特征axis上,值等于value的所有记录
     2 # 组成新的数据集retDataSet,新数据集不需要axis这个特征,注意value是特征值,axis指的是特征(所在的列下标))
     3 def splitDataSet(dataSet, axis, value):
     4     retDataSet = []
     5     for featVec in dataSet:
     6         if featVec[axis] == value:
     7             reducedFeatVec = featVec[:axis]
     8             reducedFeatVec.extend(featVec[axis+1:])
     9             retDataSet.append(reducedFeatVec)
    10     return retDataSet

    1.2 建树

    建树是一个递归的过程。

    递归结束的标志(判断某节点是叶节点的标志):

    情况1. 分到该节点的数据集中,所有记录的标签列取值都一样。

    情况2. 分到该节点的数据集中,只剩下标签列。

    a. 经判断,若是叶节点,则:

    对应情况1,返回数据集中第一条记录的标签值(反正所有标签值都一样)。

    对应情况2,返回数据集中所有标签值中,出现次数最多的那个标签值(代码中,定义一个函数majorityCnt(classList)来实现)

    b. 经判断,若不是叶节点,则:

    step1. 建立一个字典,字典的键为该数据集上选出的最佳特征(划分依据)。

    step2. 将具有相同特征值的记录组成新的数据集(利用splitDataSet(dataSet, axis, value)函数实现,注意期间抛弃了当前用于划分数据的特征列),对新的数据集们进行递归建树。

    建树代码:

     1 # 3-4 创建树的函数代码
     2 # 如果非叶子结点,则以当前数据集建树,并返回该树。该树的根节点是一个字典,键为划分当前数据集的最佳特征,值为按照键值划分后各个数据集构造的树
     3 # 叶子节点有两种:1.只剩没有特征时,叶子节点的返回值为所有记录中,出现次数最多的那个标签值 2.该叶子节点中,所有记录的标签相同。
     4 
     5 def createTree(dataSet, labels): #label向量的维度为特征数,不是记录数,是不同列下标对应的特征
     6     classList = [example[-1] for example in dataSet]
     7     if classList.count(classList[0]) == len(classList):
     8         return classList[0]
     9     if len(dataSet[0]) == 1:
    10         return majorityCnt(classList)
    11     bestFeat = chooseBestFeatureToSplit(dataSet)
    12     bestFeatLabel = labels[bestFeat]
    13     myTree = {bestFeatLabel: {}}
    14     del(labels[bestFeat])
    15     featValues = [example[bestFeat] for example in dataSet]
    16     uniqueVals = set(featValues)
    17     for value in uniqueVals:  #递归建子树,若值为字典,则非叶节点,若为字符串,则为叶节点
    18         myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), labels)
    19     return myTree

    用上面给出的数据来建立一颗决策树做示范:

    在同一个程序中输入如下代码并运行:

     1 def createDataSet():
     2     dataSet = [[1, 1, 'yes'],
     3                [1, 1, 'yes'],
     4                [1, 0, 'no'],
     5                [0, 1, 'no'],
     6                [0, 1, 'no']]
     7     labels = ['no surfacing', 'flippers']
     8     return dataSet, labels
     9 
    10 myDat, labels = createDataSet()
    11 myTree = createTree(myDat, labels)
    12 print myTree

    运行结果为:

    {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
    

     若利用后面画决策树的代码可以画出这颗决策树:

    案例:

    我们通过建立决策树来预测患者需要佩戴哪种隐形眼镜(soft(软材质)、hard(硬材质)、no lenses(不适合硬性眼睛)),数据集包含下面几个特征:age(年龄), prescript(近视还是远视), astigmatic(散光), tearRate(眼泪清除率)

    建树的结果为:

    {'tearRate': {'reduced': 'no lenses', 'normal': {'astigmatic': {'yes': {'prescript': {'hyper': {'age': {'pre': 'no lenses', 'presbyopic': 'no lenses', 'young': 'hard'}}, 'myope': 'hard'}}, 'no': {'age': {'pre': 'soft', 'presbyopic': {'prescript': {'hyper': 'soft', 'myope': 'no lenses'}}, 'young': 'soft'}}}}}}
    

    画出来是这个样子:

    画决策树的代码(不讲)

    涉及matplotlib.pyplot模块中的annotation的用法,点击链接进入官网学习这块内容的prerequisite。

     1 # _*_coding:utf-8_*_
     2 
     3 # 3-7 plotTree函数
     4 import matplotlib.pyplot as plt
     5 
     6 # 定义节点和箭头格式的常量
     7 decisionNode = dict(boxstyle="sawtooth", fc="0.8")
     8 leafNode = dict(boxstyle="round4", fc="0.8")
     9 arrow_args = dict(arrowstyle="<-")
    10 
    11 
    12 def plotMidTest(cntrPt, parentPt,txtString):
    13     xMid = (parentPt[0] + cntrPt[0])/2.0
    14     yMid = (parentPt[1] + cntrPt[1])/2.0
    15     createPlot.ax1.text(xMid, yMid, txtString)
    16 
    17 # 绘制自身
    18 # 若当前子节点不是叶子节点,递归
    19 # 若当子节点为叶子节点,绘制该节点
    20 def plotTree(myTree, parentPt, nodeTxt):
    21     numLeafs = getNumLeafs(myTree)
    22     # depth = getTreeDepth(myTree)
    23     firstStr = myTree.keys()[0]
    24     cntrPt = (plotTree.xoff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yoff)
    25     plotMidTest(cntrPt, parentPt, nodeTxt)
    26     plotNode(firstStr, cntrPt, parentPt, decisionNode)
    27     secondDict = myTree[firstStr]
    28     plotTree.yoff = plotTree.yoff - 1.0/plotTree.totalD
    29     for key in secondDict.keys():
    30         if type(secondDict[key]).__name__=='dict':
    31             plotTree(secondDict[key], cntrPt, str(key))
    32         else:
    33             plotTree.xoff = plotTree.xoff + 1.0/plotTree.totalW
    34             plotNode(secondDict[key], (plotTree.xoff, plotTree.yoff), cntrPt, leafNode)
    35             plotMidTest((plotTree.xoff, plotTree.yoff), cntrPt, str(key))
    36     plotTree.yoff = plotTree.yoff + 1.0/plotTree.totalD
    37 
    38 
    39 # figure points
    40 # 画结点的模板
    41 def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    42     createPlot.ax1.annotate(nodeTxt,  # 注释的文字,(一个字符串)
    43                             xy=parentPt,  # 被注释的地方(一个坐标)
    44                             xycoords='axes fraction',  # xy所用的坐标系
    45                             xytext=centerPt,  # 插入文本的地方(一个坐标)
    46                             textcoords='axes fraction', # xytext所用的坐标系
    47                             va="center",
    48                             ha="center",
    49                             bbox=nodeType,  # 注释文字用的框的格式
    50                             arrowprops=arrow_args)  # 箭头属性
    51 
    52 
    53 def createPlot(inTree):
    54     fig = plt.figure(1, facecolor='white')
    55     fig.clf()
    56     axprops = dict(xticks=[], yticks=[])
    57     createPlot.ax1 = plt.subplot(111,frameon=False, **axprops)
    58     plotTree.totalW = float(getNumLeafs(inTree))
    59     plotTree.totalD = float(getTreeDepth(inTree))
    60     plotTree.xoff = -0.5/plotTree.totalW
    61     plotTree.yoff = 1.0
    62 
    63     plotTree(inTree, (0.5, 1.0),'') #树的引用作为父节点,但不画出来,所以用''
    64     plt.show()
    65 
    66 def getNumLeafs(myTree):
    67     numLeafs = 0
    68     firstStr = myTree.keys()[0]
    69     secondDict = myTree[firstStr]
    70     for key in secondDict.keys():
    71         if type(secondDict[key]).__name__ =='dict':
    72             numLeafs += getNumLeafs(secondDict[key])
    73         else:
    74             numLeafs += 1
    75     return numLeafs
    76 
    77 # 子树中树高最大的那一颗的高度+1作为当前数的高度
    78 def getTreeDepth(myTree):
    79     maxDepth = 0    #用来记录最高子树的高度+1
    80     firstStr = myTree.keys()[0]
    81     secondDict = myTree[firstStr]
    82     for key in secondDict.keys():
    83         if type(secondDict[key]).__name__ == 'dict':
    84             thisDepth = 1 + getTreeDepth(secondDict[key])
    85         else:
    86             thisDepth = 1
    87         if(thisDepth > maxDepth):
    88             maxDepth = thisDepth
    89     return maxDepth
    90 
    91 # 方便测试用的人造测试树
    92 def retrieveTree(i):
    93     listofTrees = [{'no surfacing':{0:'no',1:{'flippers':{0:'no',1:'yes'}}}},
    94                    {'no surfacing':{0:'no',1:{'flippers':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}}
    95                    ]
    96     return listofTrees[i]

     

     

  • 相关阅读:
    《Machine Learning in Action》—— 白话贝叶斯,“恰瓜群众”应该恰好瓜还是恰坏瓜
    《Machine Learning in Action》—— 女同学问Taoye,KNN应该怎么玩才能通关
    《Machine Learning in Action》—— Taoye给你讲讲决策树到底是支什么“鬼”
    深度学习炼丹术 —— Taoye不讲码德,又水文了,居然写感知器这么简单的内容
    《Machine Learning in Action》—— 浅谈线性回归的那些事
    《Machine Learning in Action》—— 懂的都懂,不懂的也能懂。非线性支持向量机
    《Machine Learning in Action》—— hao朋友,快来玩啊,决策树呦
    《Machine Learning in Action》—— 剖析支持向量机,优化SMO
    《Machine Learning in Action》—— 剖析支持向量机,单手狂撕线性SVM
    JVM 字节码指令
  • 原文地址:https://www.cnblogs.com/DianeSoHungry/p/7059104.html
Copyright © 2011-2022 走看看