zoukankan      html  css  js  c++  java
  • 机器学习算法之决策树

    大家都知道二叉树,决策树算法就是利用二叉树的结构,利用数据特征对数据集进行分类,直到所有具有相同类型的数据在一个子数据集内。本文的决策树算法参照《机器学习实战第三章,使用ID3算法划分数据集。如何确定用于划分数据的数据特征呢,使用信息论中的信息熵和信息增益作为划分的度量方法。信息熵的概念源自物理热力学,在热力学中用熵表示分子状态的混乱程度,香农在信息论中用熵的来描述信息源的不确定度,可以通过以下公式定义:

    其中p(xi)为每个特征值的概率,I(xi)表示随机变量的信息,信息的定义是:对于一个事件i,它发生的概率是pi,那么它的信息就是对这个概率取对数的相反数:I(xi)=logbP(xi),其中b为底数,可以取2,10,e.

    要明白信息增益,我们还要直到条件熵,我们都知道条件概率是给定条件下某个事件发生的概率,条件熵就是给定条件下的条件干率分布的熵对X的数学期望,在机器学习中可以理解为选定某个特征后的熵:

    在知道熵、条件熵的概念后,我们就可以得到信息增益:所有分类的熵 - 某个特征值对应的条件熵:

    信息增益越大,就代表信息不确定性减少的程度最大,就是说那一个特征的条件熵对熵的影响很大,那么这个特征值就是最好的特征值。

    以下是具体的代码实现:

    # 决策树算法的代码

    import matplotlib.pyplot as plt
    
    decisionNode = dict(boxstyle='sawtooth',fc="0.8")
    leafNode = dict(boxstyle='round4',fc="0.8")
    arrow_args = dict(arrowstyle="<-")
    
    
    # 在父子节点间填充文本信息
    def plotMidText(cntrPt, parentPt, txtString):
        # 分别计算填充文文本位置的x,y坐标
        xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]
        yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
        # createPlot方法的ax1属性为一个plot视图,此处为视图添加文本
        createPlot.ax1.text(xMid,yMid,txtString)
    
    # 计算树的宽和高
    def plotTree(myTree, parentPt, nodeTxt):
        # 获取叶节点数
        numleafs = getNumLeafs(myTree)
        depth = getTreeDepth(myTree)
        # 获取树的第一个key(根节点)
        firstStr = list(myTree.keys())[0]
        # 子节点的坐标计算
        # 子节点 X坐标=节点的x偏移量 + (叶节点数 )
        cntrPt = (plotTree.xOff + (1.0 + float(numleafs))/2.0/plotTree.totalW,plotTree.yOff)
        # 填充父子节点键的文本
        plotMidText(cntrPt, parentPt, nodeTxt)
        # 绘制树节点
        plotNode(firstStr,cntrPt,parentPt,decisionNode)
        # 通过第一个key取获取value
        secondDict = myTree[firstStr]
        # 树的Y坐标偏移量
        plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
        # 对比value(所有节点名称,通过节点名称获取到对应的dict)
        for key in secondDict.keys():
            if type(secondDict[key]).__name__ == 'dict':
                # 如果遍历到字典,将调用本身绘制子节点
                plotTree(secondDict[key],cntrPt,str(key))
            else: # 已经遍历不到字典,此处已经是最后一个,将其画上
                plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
                # 绘制子节点
                plotNode(secondDict[key],(plotTree.xOff, plotTree.yOff),cntrPt, leafNode)
                # 添加节点间的文本信息
                plotMidText((plotTree.xOff, plotTree.yOff),cntrPt, str(key))
        # 确定y的偏移量
        plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
    
    
    # 创建视图
    def createPlot(inTree):
    
        fig = plt.figure(1, facecolor='White')
        fig.clf()
    
        # 不需要设置x,y的刻度文本
        axprops = dict(xticks= [], yticks=[])
        # 添加子图
        createPlot.ax1 = plt.subplot(111,frameon=False, **axprops)
        # 设置plotTree方法中的变量
        # 总的宽度 = 叶子节点的数量
        plotTree.totalW = float(getNumLeafs(inTree))
        # 总的高度 = 树的层数
        plotTree.totalD = float(getTreeDepth(inTree))
        # 定义plotTree的xOff, yOff属性的初始值
        plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0
        # 调用plotTree方法
        plotTree(inTree, (0.5, 1.0), '')
        plt.show()
    
    
    
    # 绘制树节点
    def plotNode(nodeTxt, centerPt, parentPt, nodeType):
        createPlot.ax1.annotate(nodeTxt, xy=parentPt,xycoords='axes fraction',
                                xytext=centerPt, textcoords='axes fraction',
                                va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
    
    
    # def create_plot():
    #     fig = plt.figure(1,facecolor='white')
    #     fig.clf()
    #     create_plot.ax1 = plt.subplot(111, frameon=False)
    #     plotNode('决策节点',(0.5, 0.1), (0.1, 0.5), decisionNode)
    #     plotNode('叶节点',(0.8, 0.1),(0.3, 0.8), leafNode)
    #     plt.show()
    
    
    # 获取叶节点数
    def getNumLeafs(myTree):
        # 初始化叶节点的计数
        numLeafs = 0
        # 从myTree的所有节点获取第一个节点(根节点)
        firstStr = list(myTree.keys())[0]
        # 通过跟节点的key取出根key对应的value
        secondDict = myTree[firstStr]
        # 遍历根key的value(value包含根key包含的余下所有的子节点)
        # 上一级的value包含下一级的key,因此通过递归,可以不断取到下一层的value
        for key in secondDict.keys():
            # 只要获取到的value的是字典的类型,就进行递归,接着往下取叶节点
            if type(secondDict[key]).__name__ == 'dict':
                # 每次递归调用该函数都会获取到该节点下的所有叶节点,并进行计数
                numLeafs += getNumLeafs(secondDict[key])
            # 如果获取的vlaue不再是字典,说明已经是最后一个子节点,进行一次加1操作
            else: numLeafs += 1
        return numLeafs
    
    
    # 获取树的层数
    
    def getTreeDepth(myTree):
        # 树的层数与获取叶节点的步骤相似,区别在于
        # 叶节点数每遍历一次,如果遍历到叶子节点,那么将计数加一,累计叶子节点的个数;
        # 树层数的计数在递归的过程中,如果遍历到叶子节点,就会将计数值置为1,只保留max的计数。
        # 将这一层的深度记为1
    
    
        # 初始化一个记录最大深度的变量
        maxDepth = 0
        firstStr = list(myTree.keys())[0]
        secondDict = myTree[firstStr]
        for key in secondDict.keys():
            if type(secondDict[key]).__name__ == 'dict':
            # 每次递归都进行依次+1的计数操作
                thisDepth = 1 + getTreeDepth(secondDict[key])
            # 如果没有遍历到dict,只有只有一层
            else: thisDepth = 1
            # 每一个key对用的子节点串(每一条路径)都会有一个最大值,记录其中最大的那个
            if thisDepth > maxDepth: maxDepth = thisDepth
        return maxDepth
    
    
    # 输出预先存储的树信息
    def retriveTree(i):
        listOfTrees = [{'no surfacing': {0: 'no', 1:{'flippers':{0:'no',1:'yes'}}}},
                       {'no surfacing': {0: 'no', 1:{'flippers':{0: {'head':{0: 'no', 1:'yes'}}, 1: 'no'}}}}]
    
        return listOfTrees[i]
    

    # 绘制决策树的代码

    import matplotlib.pyplot as plt
    from cha03_trees import trees
    
    decisionNode = dict(boxstyle='sawtooth',fc="0.8")
    leafNode = dict(boxstyle='round4',fc="0.8")
    arrow_args = dict(arrowstyle="<-")
    
    
    # 在父子节点间填充文本信息
    def plotMidText(cntrPt, parentPt, txtString):
        # 分别计算填充文文本位置的x,y坐标
        xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]
        yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
        # createPlot方法的ax1属性为一个plot视图,此处为视图添加文本
        createPlot.ax1.text(xMid,yMid,txtString)
    
    # 计算树的宽和高
    def plotTree(myTree, parentPt, nodeTxt):
        # 获取叶节点数
        numleafs = getNumLeafs(myTree)
        depth = getTreeDepth(myTree)
        # 获取树的第一个key(根节点)
        firstStr = list(myTree.keys())[0]
        # 子节点的坐标计算
        # 子节点 X坐标=节点的x偏移量 + (叶节点数 )
        cntrPt = (plotTree.xOff + (1.0 + float(numleafs))/2.0/plotTree.totalW,plotTree.yOff)
        # 填充父子节点键的文本
        plotMidText(cntrPt, parentPt, nodeTxt)
        # 绘制树节点
        plotNode(firstStr,cntrPt,parentPt,decisionNode)
        # 通过第一个key取获取value
        secondDict = myTree[firstStr]
        # 树的Y坐标偏移量
        plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
        # 对比value(所有节点名称,通过节点名称获取到对应的dict)
        for key in secondDict.keys():
            if type(secondDict[key]).__name__ == 'dict':
                # 如果遍历到字典,将调用本身绘制子节点
                plotTree(secondDict[key],cntrPt,str(key))
            else: # 已经遍历不到字典,此处已经是最后一个,将其画上
                plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
                # 绘制子节点
                plotNode(secondDict[key],(plotTree.xOff, plotTree.yOff),cntrPt, leafNode)
                # 添加节点间的文本信息
                plotMidText((plotTree.xOff, plotTree.yOff),cntrPt, str(key))
        # 确定y的偏移量
        plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
    
    
    # 创建视图
    def createPlot(inTree):
    
        fig = plt.figure(1, facecolor='White')
        fig.clf()
    
        # 不需要设置x,y的刻度文本
        axprops = dict(xticks= [], yticks=[])
        # 添加子图
        createPlot.ax1 = plt.subplot(111,frameon=False, **axprops)
        # 设置plotTree方法中的变量
        # 总的宽度 = 叶子节点的数量
        plotTree.totalW = float(getNumLeafs(inTree))
        # 总的高度 = 树的层数
        plotTree.totalD = float(getTreeDepth(inTree))
        # 定义plotTree的xOff, yOff属性的初始值
        plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0
        # 调用plotTree方法
        plotTree(inTree, (0.5, 1.0), '')
        plt.show()
    
    
    
    # 绘制树节点
    def plotNode(nodeTxt, centerPt, parentPt, nodeType):
        createPlot.ax1.annotate(nodeTxt, xy=parentPt,xycoords='axes fraction',
                                xytext=centerPt, textcoords='axes fraction',
                                va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
    
    
    # def create_plot():
    #     fig = plt.figure(1,facecolor='white')
    #     fig.clf()
    #     create_plot.ax1 = plt.subplot(111, frameon=False)
    #     plotNode('决策节点',(0.5, 0.1), (0.1, 0.5), decisionNode)
    #     plotNode('叶节点',(0.8, 0.1),(0.3, 0.8), leafNode)
    #     plt.show()
    
    
    # 获取叶节点数
    def getNumLeafs(myTree):
        # 初始化叶节点的计数
        numLeafs = 0
        # 从myTree的所有节点获取第一个节点(根节点)
        firstStr = list(myTree.keys())[0]
        # 通过跟节点的key取出根key对应的value
        secondDict = myTree[firstStr]
        # 遍历根key的value(value包含根key包含的余下所有的子节点)
        # 上一级的value包含下一级的key,因此通过递归,可以不断取到下一层的value
        for key in secondDict.keys():
            # 只要获取到的value的是字典的类型,就进行递归,接着往下取叶节点
            if type(secondDict[key]).__name__ == 'dict':
                # 每次递归调用该函数都会获取到该节点下的所有叶节点,并进行计数
                numLeafs += getNumLeafs(secondDict[key])
            # 如果获取的vlaue不再是字典,说明已经是最后一个子节点,进行一次加1操作
            else: numLeafs += 1
        return numLeafs
    
    
    # 获取树的层数
    
    def getTreeDepth(myTree):
        # 树的层数与获取叶节点的步骤相似,区别在于
        # 叶节点数每遍历一次,如果遍历到叶子节点,那么将计数加一,累计叶子节点的个数;
        # 树层数的计数在递归的过程中,如果遍历到叶子节点,就会将计数值置为1,只保留max的计数。
        # 将这一层的深度记为1
    
    
        # 初始化一个记录最大深度的变量
        maxDepth = 0
        firstStr = list(myTree.keys())[0]
        secondDict = myTree[firstStr]
        for key in secondDict.keys():
            if type(secondDict[key]).__name__ == 'dict':
            # 每次递归都进行依次+1的计数操作
                thisDepth = 1 + getTreeDepth(secondDict[key])
            # 如果没有遍历到dict,只有只有一层
            else: thisDepth = 1
            # 每一个key对用的子节点串(每一条路径)都会有一个最大值,记录其中最大的那个
            if thisDepth > maxDepth: maxDepth = thisDepth
        return maxDepth
    
    
    # 输出预先存储的树信息
    def retriveTree(i):
        listOfTrees = [{'no surfacing': {0: 'no', 1:{'flippers':{0:'no',1:'yes'}}}},
                       {'no surfacing': {0: 'no', 1:{'flippers':{0: {'head':{0: 'no', 1:'yes'}}, 1: 'no'}}}}]
    
        return listOfTrees[i]
    
    
    if __name__ == '__main__':
        # myTree = retriveTree(0)
        # createPlot(myTree)
        fr = open("../cha03_trees/lenses.txt")
        lenses = [inst.strip().split('	') for inst in fr.readlines()]
        lensLabels = ['age', 'prescipt','astigmatic','tearRate']
        lensesTree = trees.createTree(lenses, lensLabels)
        createPlot(lensesTree)

    运行结果:

    代码地址:https://github.com/ZhaoJiangJie/MLInAction/tree/master/cha03_trees

    参考:1.《机器学习实战》peter Harrington 著

             2.https://www.cnblogs.com/fantasy01/p/4581803.html

             3.https://www.zhihu.com/question/22104055

             4.http://blog.csdn.net/aws3217150/article/details/49906389

  • 相关阅读:
    vue之插槽
    微信公众号-关注和取消关注
    微信公众号-消息响应
    微信公众号-验证接入
    微信公众号-开发工具natapp内网穿透安装和使用
    windows安装PHP5.4+Apache2.4+Mysql5.5
    php各种主流框架的优缺点总结
    php框架的特性总结
    什么是php?php的优缺点有哪些?与其它编程语言的优缺点?
    二进制、八进制、十进制、十六进制之间转换
  • 原文地址:https://www.cnblogs.com/thsk/p/8465666.html
Copyright © 2011-2022 走看看