zoukankan      html  css  js  c++  java
  • [置顶] 决策树绘图(二)

    由于最近在看机器学习实战,所以自己利用python3去完成里面的代码,此代码衔接着http://blog.csdn.net/xueyunf/article/details/9223865

    在这个基础上进行修改完成了这篇文章的代码,我们知道了决策树的简单构建,ID3算法完成,当然这都很基础,画图呢,只是为了让其更加形象化;我们添加几个函数,一个是输出一棵我们可以利用ID3算法生成的树,一个获取树的叶子节点,一个获取树的深度,这些我想这里就不用讲解了,学过数据结构的童鞋,可以在非常短的时间内实现这些算法;当然我先把这3个函数的代码贴出来:

    def  getNumLeafs(myTree):
        numLeafs = 0
        firstStr = list(myTree.keys())[0]
        secondDict  =  myTree[firstStr]
        for key in secondDict.keys():
            if type(secondDict[key]).__name__=='dict':
                numLeafs += getNumLeafs(secondDict[key])
            else:
                numLeafs += 1
        return numLeafs
        
    def getTreeDepth(myTree):
        maxDepth = 0
        firstStr = list(myTree.keys())[0]
        secondDict = myTree[firstStr]
        for key in secondDict.keys():
            if type(secondDict[key]).__name__=='dict':
                thisDepth = 1 + getTreeDepth(secondDict[key])
            else:
                thisDepth = 1
            if thisDepth>maxDepth:
                maxDepth = thisDepth
        return maxDepth
    
    def retrieveTree(i):
        listOfTrees = [{'no surfacing':{0:'no', 1:{'flippers':
                        {0:'no', 1:'yes'}}}},
                       {'no surfacing':{0:'no', 1:{'flippers':
                          {0:{'head':{0:'no', 1:'yes'}}, 1:'no'}}}}
                    ] 
                    
        return listOfTrees[i]


    我们不难看出根据定义第一个函数完成获取所有叶子节点的个数,第二个函数完成获取树的高度,第三个函数完成输出树。

    然后我们放出这次的主要函数,修改后的绘图函数:

    def plotMidText(cntrPt, parentPt, txtString):
        xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
        yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
        createPlot.ax1.text(xMid, yMid, txtString)
    
    def plotTree(myTree, parentPt, nodeTxt):
        numLeafs = getNumLeafs(myTree)
        getTreeDepth(myTree)
        firstStr = list(myTree.keys())[0]
        cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW,
        plotTree.yOff)
        plotMidText(cntrPt, parentPt, nodeTxt)
        plotNode(firstStr, cntrPt, parentPt, decisionNode)
        secondDict = myTree[firstStr]
        plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
        for key in secondDict.keys():
            if type(secondDict[key]).__name__=='dict':
                plotTree(secondDict[key],cntrPt,str(key))
            else:
                plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
                plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),
                cntrPt, leafNode)
                plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
        plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
    
    def createPlot(inTree):
        fig = plt.figure(1, facecolor='white')
        fig.clf()
        axprops = dict(xticks=[], yticks=[])
        createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
        plotTree.totalW = float(getNumLeafs(inTree))
        plotTree.totalD = float(getTreeDepth(inTree))
        plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
        plotTree(inTree, (0.5,1.0), '')
        plt.show()
    


    最后当然也是截个图给大家看看程序的运行情况:


    好了,这里面的函数我想大家可以通过名字也知道每个函数干了些什么。

  • 相关阅读:
    C++ STL 一般总结(转载)
    Python高级语法总结
    Adaboost 算法的原理与推导——转载及修改完善
    简化版SMO算法标注
    【转载】机器学习算法基础概念学习总结
    C++中添加配置文件读写方法
    Python中scatter()函数--转载
    python 之 strip()--(转载)
    zabbix邮件报警脚本(Python)
    Linux常用命令
  • 原文地址:https://www.cnblogs.com/jiangu66/p/3170473.html
Copyright © 2011-2022 走看看