zoukankan      html  css  js  c++  java
  • 【机器学习】决策树-02

    心得体会:

      1.使用字典表示的树和matplotlib绘图

      2.决策树可以用pickle以二进制模式‘wb’写入文件,再以二进制模式‘rb’从文件中读取

    #3.2Matplotlib注解绘制树形图
    #使用文本注解绘制树节点
    import matplotlib
    import matplotlib.pyplot as plt
    
    # Box style passed as ``bbox`` when drawing decision nodes.
    decisionNode = {"boxstyle": "sawtooth", "fc": "0.8"}
    # Box style intended for leaf nodes.
    leafNode = {"boxstyle": "round4", "fc": "0.8"}
    # Arrow style passed as ``arrowprops`` to every annotation.
    arrow_args = {"arrowstyle": "<-"}
    
    #在图中添加这些点
    def plotNode(nodeTxt, centerPt, parentPt, nodeType):
        """Annotate the shared axes with one boxed tree node.

        The label ``nodeTxt`` is placed at ``centerPt`` inside a box styled by
        ``nodeType``; the annotation's ``xy`` anchor is ``parentPt`` and the
        connector uses ``arrow_args``. Both points are interpreted as
        axes-fraction coordinates. Drawing goes to ``createPlot.ax1``, which
        must have been assigned before this function is called.
        """
        axes = createPlot.ax1
        annotationOptions = dict(
            xy=parentPt,
            xycoords='axes fraction',
            xytext=centerPt,
            textcoords='axes fraction',
            va="center",
            bbox=nodeType,
            arrowprops=arrow_args,
        )
        axes.annotate(nodeTxt, **annotationOptions)
    # def createPlot():
    #     fig=plt.figure(1,facecolor='white')#图像编号1,背景色白色
    #     fig.clf() # Clear figure清除所有轴,但是窗口打开,这样它可以被重复使用
    #     createPlot.ax1=plt.subplot(111,frameon=False)# 1行1列,位置是1的子图——createPlot.ax1是plt子图的索引,可以通过ax1设计plt子图
    #     plotNode('决策节点',(0.5,0.1),(0.1,0.5),decisionNode)
    #     plotNode('叶节点',(0.8,0.1),(0.0,0.0),leafNode)
    #     plt.show()
    
    #注意:使用matplotlib时不要用qq输入法
    # createPlot()
    
    #构造注解树
    
    #获取叶节点的数目
    def getNumLeafs(myTree):
        """Count the leaf nodes of a nested-dict decision tree.

        Args:
            myTree: Mapping whose first key is the splitting feature and whose
                value is a dict from feature value to either a subtree (any
                ``dict`` instance, including subclasses) or a leaf.

        Returns:
            The number of non-dict values reachable from the root; ``0`` for
            an empty tree.
        """
        if not myTree:
            return 0
        firstStr = next(iter(myTree))
        secondDict = myTree[firstStr]
        # isinstance (not type() ==) so dict subclasses are walked as subtrees.
        return sum(
            getNumLeafs(child) if isinstance(child, dict) else 1
            for child in secondDict.values()
        )
    
    #获得树的层数
    def getTreeDepth(myTree):
        """Return the number of decision levels in a nested-dict decision tree.

        Args:
            myTree: Mapping whose first key is the splitting feature and whose
                value is a dict from feature value to either a subtree (any
                ``dict`` instance, including subclasses) or a leaf.

        Returns:
            The length of the longest root-to-leaf chain of decision nodes;
            ``0`` for an empty tree or a root with no branches.
        """
        if not myTree:
            return 0
        firstStr = next(iter(myTree))
        secondDict = myTree[firstStr]
        # A leaf branch contributes depth 1; a subtree adds its own depth on top.
        return max(
            (
                1 + getTreeDepth(child) if isinstance(child, dict) else 1
                for child in secondDict.values()
            ),
            default=0,
        )
    
    #获得一颗树的数据
    def retrieveTree():
        """Build and return a tree from the data set given by ``createDataSet()``.

        The data and labels returned by ``createDataSet()`` are passed straight
        to ``createTree``; both helpers are defined outside this section.
        """
        dataSet, featureLabels = createDataSet()
        return createTree(dataSet, featureLabels)
    
    # mytree=retrieveTree()
    # print(getNumLeafs(mytree))
    # print(getTreeDepth(mytree))
    
    #plotTree函数
    def plotMidTest(cntrPt, parentPt, txtString):
        """Write ``txtString`` halfway between ``cntrPt`` and ``parentPt``.

        Only the first two components of each point are used. The text is
        drawn on ``createPlot.ax1``.
        """
        childX, childY = cntrPt[0], cntrPt[1]
        parentX, parentY = parentPt[0], parentPt[1]
        midpoint = (
            (parentX - childX) / 2.0 + childX,
            (parentY - childY) / 2.0 + childY,
        )
        createPlot.ax1.text(midpoint[0], midpoint[1], txtString)
    
    def plotTree(myTree, parentPt, nodeTxt):
        """Recursively draw ``myTree`` beneath ``parentPt`` on ``createPlot.ax1``.

        Layout state is kept in function attributes that ``createPlot``
        initialises before the first call: ``plotTree.totalW`` and
        ``plotTree.totalD`` (leaf count and depth of the whole tree) and
        ``plotTree.xOff`` / ``plotTree.yOff`` (the current drawing cursor).

        Args:
            myTree: Nested-dict tree ``{feature: {value: subtree_or_leaf}}``.
            parentPt: ``(x, y)`` of the parent node.
            nodeTxt: Label written on the edge between the parent and this node.
        """
        numLeafs = getNumLeafs(myTree)
        firstStr = list(myTree.keys())[0]
        # Centre this node horizontally over the leaves it spans.
        cntrPt = (
            plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
            plotTree.yOff,
        )
        plotMidTest(cntrPt, parentPt, nodeTxt)
        plotNode(firstStr, cntrPt, parentPt, decisionNode)
        secondDict = myTree[firstStr]
        # Move the cursor down one level while drawing the children.
        plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
        for key in secondDict.keys():
            child = secondDict[key]
            if isinstance(child, dict):
                plotTree(child, cntrPt, str(key))
            else:
                # Each leaf takes the next horizontal slot.
                plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
                leafPt = (plotTree.xOff, plotTree.yOff)
                # Leaves use the leaf box style, not the decision-node style.
                plotNode(child, leafPt, cntrPt, leafNode)
                plotMidTest(leafPt, cntrPt, str(key))
        # Restore the cursor height for the caller's remaining siblings.
        plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
    
    def createPlot(inTree):
        """Draw ``inTree`` in figure 1 and show it.

        Clears figure 1, creates a frameless subplot without ticks and stores
        it on ``createPlot.ax1``, initialises the layout attributes on
        ``plotTree`` and then draws the tree starting from ``(0.5, 1.0)``.
        """
        figure = plt.figure(1, facecolor='white')
        figure.clf()
        # The axes are stored on the function so the plotting helpers can reach them.
        createPlot.ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])
        plotTree.totalW = float(getNumLeafs(inTree))
        plotTree.totalD = float(getTreeDepth(inTree))
        # Start half a leaf-slot left of the axes; each leaf advances xOff by one slot.
        plotTree.xOff = -0.5 / plotTree.totalW
        plotTree.yOff = 1.0
        rootPt = (0.5, 1.0)
        plotTree(inTree, rootPt, '')
        plt.show()
    
    # createPlot(retrieveTree())
    
    # 3-3测试和存储分类器
    def classify(inputTree, featLabels, testVec):
        """Classify ``testVec`` by walking a nested-dict decision tree.

        Args:
            inputTree: Nested-dict tree ``{feature: {value: subtree_or_leaf}}``.
            featLabels: Feature names; the position of a name gives the index
                of that feature's value in ``testVec``.
            testVec: Feature values of the sample to classify.

        Returns:
            The leaf value reached by following the sample's feature values.

        Raises:
            ValueError: If a feature name is missing from ``featLabels`` or the
                sample has a feature value with no matching branch.
        """
        firstStr = list(inputTree.keys())[0]
        secondDict = inputTree[firstStr]
        featIndex = featLabels.index(firstStr)
        featValue = testVec[featIndex]
        for key, branch in secondDict.items():
            if featValue == key:
                # Stop at the first matching branch instead of scanning the rest.
                if isinstance(branch, dict):
                    return classify(branch, featLabels, testVec)
                return branch
        # Fail with a clear message instead of hitting an unbound local.
        raise ValueError(
            "no branch for value %r of feature %r" % (featValue, firstStr)
        )
    
    #使用算法:决策树的存储
    def storeTree(inputTree, filename):
        """Serialize ``inputTree`` to ``filename`` with pickle.

        Args:
            inputTree: The object to store.
            filename: Path of the file to write; opened in binary mode and
                overwritten if it exists.
        """
        import pickle
        # The context manager closes the file even if pickling fails.
        with open(filename, 'wb') as fw:
            pickle.dump(inputTree, fw)
    
    def grabTree(filename):
        """Load and return the object pickled in ``filename``.

        Args:
            filename: Path of a file previously written with pickle.

        Returns:
            The unpickled object.
        """
        import pickle
        # NOTE(security): unpickling can execute arbitrary code; only load
        # files from a trusted source.
        # The context manager closes the file once loading finishes or fails.
        with open(filename, 'rb') as fr:
            return pickle.load(fr)
    
    # myTree=retrieveTree()
    # storeTree(myTree,"E:/Python/PycharmProjects/机器学习实战/Include/第03章_决策树/s.txt")
    # print(grabTree("E:/Python/PycharmProjects/机器学习实战/Include/第03章_决策树/s.txt"))


    #示例:使用决策树预测隐形眼镜的类型
    # Read one sample per line; the context manager closes the file afterwards.
    lenses = []
    with open("E:/Python/《机器学习实战》代码/Ch03/lenses.txt") as fr:
        for data in fr:
            # NOTE(review): fields are split on a single space; confirm this
            # matches the delimiter used in the data file.
            lenses.append(data.strip().split(' '))
    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    # Build the tree from the samples and plot it.
    lensesTree = createTree(lenses, lensesLabels)
    createPlot(lensesTree)
     
  • 相关阅读:
    A problem occurred evaluating project ':'. > ASCII
    算法------买卖股票的最佳时机
    源码解析:解析掌阅X2C 框架
    Java 基础------16进制转2进制
    2018年年度总结,以及2019年规划
    算法-----三数之和等于0
    算法--------数组类---------总结
    算法------长度最小的子数组
    算法--------数组--------容纳最多的水
    算法--------数组------反转字符串中的元音字母
  • 原文地址:https://www.cnblogs.com/LPworld/p/13272339.html
Copyright © 2011-2022 走看看