  • day-8: Implementing the ID3 Decision Tree Algorithm with Python's Built-in Libraries

      Yesterday we implemented an ID3 decision tree on top of the sklearn library; in this post we implement the ID3 algorithm using Python's built-in libraries.

     I. Background Knowledge Used in the Code

      1. For convenient plotting, a third-party treePlotter module is used for drawing (its source is given in Section IV). It is simple to use: call its createPlot function with a tree-structured object and it draws the corresponding figure.

      2. Defining a tree structure in Python

        We can use Python's built-in dict type to define a tree. For example, the following code defines a root node with a left and a right child:

        import treePlotter

        rootNode = {'rootNode': {}}
        leftNode = {'leftNode': {'yes': 'yes'}}
        rightNode = {'rightNode': {'no': 'no'}}
        rootNode['rootNode']['left'] = leftNode
        rootNode['rootNode']['right'] = rightNode
        treePlotter.createPlot(rootNode)

        Calling the treePlotter module then draws the corresponding tree: a root node with a left and a right child, each ending in a leaf.

        (figure of the rendered tree omitted)

      3. Recursion

        To collect the children of every node we use recursion, with the same basic idea as traversing a binary tree (we will implement a binary tree in Python in a later post, so it is not covered here). A minimal sketch of the pattern follows.
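        As an illustration (not part of the final program), here is that recursive pattern applied to the nested-dict representation above, counting leaf nodes the same way the plotting module does:

    # a tiny tree in the nested-dict representation from above
    tree = {'root': {'left': {'leftNode': {'yes': 'yes'}},
                     'right': {'rightNode': {'no': 'no'}}}}

    def countLeaves(t):
        # a plain value (not a dict) is a leaf label
        if not isinstance(t, dict):
            return 1
        # the single key is the node name; its value maps each branch to a child
        node = list(t.keys())[0]
        return sum(countLeaves(child) for child in t[node].values())

    print(countLeaves(tree))  # 2: the 'yes' and 'no' leaves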

      4. In addition, familiarity with Python's common data types and their operations (dicts, lists, sets, etc.) is assumed; the snippet below shows the idioms the program leans on.
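        For example (using a made-up two-row dataset, purely illustrative), the feature-selection code later extracts a column with a list comprehension, deduplicates it with a set, and counts labels with a dict:

    dataset = [['youth', 'no'], ['senior', 'yes']]  # hypothetical rows
    ages = [row[0] for row in dataset]              # list comprehension: one column
    uniqueVals = set(ages)                          # set: distinct feature values
    counts = {}
    for label in (row[-1] for row in dataset):      # dict: frequency counting
        counts[label] = counts.get(label, 0) + 1
    print(uniqueVals, counts)                       # {'youth', 'senior'} {'no': 1, 'yes': 1}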

    II. Main Program Flow

    (flow chart omitted) In outline: read the CSV dataset, whose header row gives the feature names; build the tree recursively with createTree, splitting at each node on the feature with the largest information gain; save and reload the tree with pickle; draw it with treePlotter.createPlot; and finally classify test vectors by walking the tree.

    III. Test Dataset

    The 14-sample "buys computer" dataset, stored as test.csv (header row first):

    age,income,student,credit_rating,class_buys_computer
    youth,high,no,fair,no
    youth,high,no,excellent,no
    middle_aged,high,no,fair,yes
    senior,medium,no,fair,yes
    senior,low,yes,fair,yes
    senior,low,yes,excellent,no
    middle_aged,low,yes,excellent,yes
    youth,medium,no,fair,no
    youth,low,yes,fair,yes
    senior,medium,yes,fair,yes
    youth,medium,yes,excellent,yes
    middle_aged,medium,no,excellent,yes
    middle_aged,high,yes,fair,yes
    senior,medium,no,excellent,no
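    If you prefer to generate the file programmatically rather than typing it in, a minimal sketch (the D:\test.csv path is only what main() below expects; adjust it as needed):

    import csv

    rows = [
        ['age', 'income', 'student', 'credit_rating', 'class_buys_computer'],
        ['youth', 'high', 'no', 'fair', 'no'],
        ['youth', 'high', 'no', 'excellent', 'no'],
        ['middle_aged', 'high', 'no', 'fair', 'yes'],
        ['senior', 'medium', 'no', 'fair', 'yes'],
        ['senior', 'low', 'yes', 'fair', 'yes'],
        ['senior', 'low', 'yes', 'excellent', 'no'],
        ['middle_aged', 'low', 'yes', 'excellent', 'yes'],
        ['youth', 'medium', 'no', 'fair', 'no'],
        ['youth', 'low', 'yes', 'fair', 'yes'],
        ['senior', 'medium', 'yes', 'fair', 'yes'],
        ['youth', 'medium', 'yes', 'excellent', 'yes'],
        ['middle_aged', 'medium', 'no', 'excellent', 'yes'],
        ['middle_aged', 'high', 'yes', 'fair', 'yes'],
        ['senior', 'medium', 'no', 'excellent', 'no'],
    ]
    with open(r'D:\test.csv', 'w', newline='') as f:
        csv.writer(f).writerows(rows)  # write header plus the 14 samples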

    IV. Program Code

             1. Computing the entropy and information gain of the test set
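    The function below implements the textbook definitions: the Shannon entropy of a dataset D with class proportions p_k, and the information gain of splitting on attribute a, where D_v is the subset of D whose value on a is v:

    H(D) = -\sum_{k} p_k \log_2 p_k,
    \qquad
    \mathrm{Gain}(D,a) = H(D) - \sum_{v \in \mathrm{Values}(a)} \frac{|D_v|}{|D|}\, H(D_v)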

    # choose the best feature to split on (the one with the largest information gain)
    def chooseBestFeatureToSplit(dataset, headerList):
        # initial values
        bestInfoGain = 0.0
        bestFeature = 0
        # number of feature columns (the last column is the class label)
        numFeatures = len(dataset[0]) - 1
        # entropy of the whole dataset
        baseShannonEnt = calcShannonEnt(dataset)
        print("total's shannonEnt = %f" % (baseShannonEnt))
        # compute the information gain of each feature column
        for i in range(numFeatures):
            # all values appearing in this column
            featureList = [example[i] for example in dataset]
            uniqueVals = set(featureList)
            # weighted conditional entropy of splitting on this column
            newShannonEnt = 0.0
            for value in uniqueVals:
                subDataset = splitDataSet(dataset, i, value)
                # probability of this value, times the entropy of its subset
                newProbability = len(subDataset) / float(len(dataset))
                newShannonEnt += newProbability * calcShannonEnt(subDataset)
            infoGain = baseShannonEnt - newShannonEnt
            print("%s's infoGain = %f - %f = %f" % (headerList[i], baseShannonEnt, newShannonEnt, infoGain))
            if infoGain > bestInfoGain:
                bestInfoGain = infoGain
                bestFeature = i
        return bestFeature

      This matches yesterday's calculation: the age feature has the largest information gain, so it becomes the root node:

      (console printout of the per-feature information gains omitted)
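      As a hand check of those printed values: the full dataset has 9 "yes" and 5 "no" labels, and splitting on age gives subsets youth (2 yes / 3 no), middle_aged (4 yes / 0 no), and senior (3 yes / 2 no), so:

    H(D) = -\tfrac{9}{14}\log_2\tfrac{9}{14} - \tfrac{5}{14}\log_2\tfrac{5}{14} \approx 0.940

    \mathrm{Gain}(D,\text{age}) = 0.940 - \bigl(\tfrac{5}{14}\cdot 0.971 + \tfrac{4}{14}\cdot 0 + \tfrac{5}{14}\cdot 0.971\bigr) \approx 0.246

      where 0.971 is the entropy of a 2-vs-3 split of five samples.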

             2. Full source

             treePlotter.py

    import matplotlib.pyplot as plt

    # node and arrow styles for the rendered decision tree
    decisionNode = dict(boxstyle='sawtooth', fc='0.8')
    leafNode = dict(boxstyle='round4', fc='0.8')
    arrow_args = dict(arrowstyle='<-')

    def plotNode(nodeTxt, centerPt, parentPt, nodeType):
        # nodeTxt: text to display; centerPt: where the text box goes;
        # parentPt: where the arrow starts; nodeType: box style (decision or leaf)
        createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                                xytext=centerPt, textcoords='axes fraction',
                                va='center', ha='center', bbox=nodeType, arrowprops=arrow_args)

    def getNumLeafs(myTree):
        numLeafs = 0
        firstStr = list(myTree.keys())[0]   # wrap in list(): in Python 3, dict.keys() returns a view, not a list
        secondDict = myTree[firstStr]
        for key in secondDict.keys():
            if type(secondDict[key]).__name__ == 'dict':
                numLeafs += getNumLeafs(secondDict[key])
            else:
                numLeafs += 1
        return numLeafs

    def getTreeDepth(myTree):
        maxDepth = 0
        firstStr = list(myTree.keys())[0]
        secondDict = myTree[firstStr]
        for key in secondDict.keys():
            if type(secondDict[key]).__name__ == 'dict':
                thisDepth = 1 + getTreeDepth(secondDict[key])
            else:
                thisDepth = 1
            if thisDepth > maxDepth:
                maxDepth = thisDepth
        return maxDepth

    def createPlot(inTree):
        fig = plt.figure(1, facecolor='white')
        fig.clf()
        axprops = dict(xticks=[], yticks=[])   # hide the axis ticks
        createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
        plotTree.totalW = float(getNumLeafs(inTree))    # global width = number of leaves
        plotTree.totalD = float(getTreeDepth(inTree))   # global height = tree depth
        plotTree.xOff = -0.5 / plotTree.totalW
        plotTree.yOff = 1.0
        plotTree(inTree, (0.5, 1.0), '')
        plt.show()

    def plotTree(myTree, parentPt, nodeTxt):
        numLeafs = getNumLeafs(myTree)
        firstStr = list(myTree.keys())[0]
        # cntrPt: center of this node's text; parentPt: point the arrow comes from
        cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
        plotMidText(cntrPt, parentPt, nodeTxt)
        plotNode(firstStr, cntrPt, parentPt, decisionNode)
        secondDict = myTree[firstStr]
        plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
        for key in secondDict.keys():
            if type(secondDict[key]).__name__ == 'dict':
                plotTree(secondDict[key], cntrPt, str(key))
            else:
                plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
                plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
                plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
        plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

    def plotMidText(cntrPt, parentPt, txtString):
        # label the branch midway between parent and child
        xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
        yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
        createPlot.ax1.text(xMid, yMid, txtString, va='center', ha='center', rotation=30)

             decision_tree_ID3.py

    # imports
    import csv
    import math
    import operator
    import treePlotter
    
    
    # load the dataset from a CSV file
    def readDataset(file_path, file_mode):
        with open(file_path, file_mode) as allElectronicsData:
            reader = csv.reader(allElectronicsData)
            # the first row holds the feature names
            headers = next(reader)
            # the remaining rows are the samples
            dataset = [row for row in reader]
        return headers, dataset
    
    # Shannon entropy of a dataset (the class label is the last column)
    def calcShannonEnt(dataset):
        shannonEnt = 0.0
        labelCounts = {}
        for vec_now in dataset:
            labelValue = vec_now[-1]
            if labelValue not in labelCounts:
                labelCounts[labelValue] = 0
            labelCounts[labelValue] += 1
        for labelKey in labelCounts:
            probability = labelCounts[labelKey] / float(len(dataset))
            shannonEnt -= probability * math.log(probability, 2)
        return shannonEnt
    
    # extract the samples whose column feature_seq equals value, dropping that column
    def splitDataSet(dataset, feature_seq, value):
        new_dataset = []
        for vec_row in dataset:
            if vec_row[feature_seq] == value:
                temp_vec = vec_row[:feature_seq]
                temp_vec.extend(vec_row[feature_seq + 1:])
                new_dataset.append(temp_vec)
        return new_dataset
    
    # choose the best feature to split on (the one with the largest information gain)
    def chooseBestFeatureToSplit(dataset, headerList):
        # initial values
        bestInfoGain = 0.0
        bestFeature = 0
        # number of feature columns (the last column is the class label)
        numFeatures = len(dataset[0]) - 1
        # entropy of the whole dataset
        baseShannonEnt = calcShannonEnt(dataset)
        #print("total's shannonEnt = %f" % (baseShannonEnt))
        # compute the information gain of each feature column
        for i in range(numFeatures):
            # all values appearing in this column
            featureList = [example[i] for example in dataset]
            uniqueVals = set(featureList)
            # weighted conditional entropy of splitting on this column
            newShannonEnt = 0.0
            for value in uniqueVals:
                subDataset = splitDataSet(dataset, i, value)
                # probability of this value, times the entropy of its subset
                newProbability = len(subDataset) / float(len(dataset))
                newShannonEnt += newProbability * calcShannonEnt(subDataset)
            infoGain = baseShannonEnt - newShannonEnt
            #print("%s's infoGain = %f - %f = %f" % (headerList[i], baseShannonEnt, newShannonEnt, infoGain))
            if infoGain > bestInfoGain:
                bestInfoGain = infoGain
                bestFeature = i
        return bestFeature
    
    # majority vote: return the most frequent class label
    def majorityCnt(classList):
        classcount = {}
        for cl in classList:
            if cl not in classcount:
                classcount[cl] = 0
            classcount[cl] += 1
        sortedClassCount = sorted(classcount.items(), key=operator.itemgetter(1), reverse=True)
        return sortedClassCount[0][0]
    
    # build the decision tree recursively
    def createTree(dataSet, labels):
        classList = [example[-1] for example in dataSet]
        # 1. all samples share one class: return it directly
        if classList.count(classList[0]) == len(classList):
            return classList[0]
        # 2. all features used up: majority vote on the remaining labels
        if len(dataSet[0]) == 1:
            return majorityCnt(classList)
        # 3. otherwise, split on the best feature
        bestFeature = chooseBestFeatureToSplit(dataSet, labels)
        bestFeatureLabel = labels[bestFeature]
        myTree = {bestFeatureLabel: {}}
        del (labels[bestFeature])
        featureValues = [example[bestFeature] for example in dataSet]
        uniqueVals = set(featureValues)
        # recurse to build the subtree under each feature value
        for value in uniqueVals:
            subLabels = labels[:]
            myTree[bestFeatureLabel][value] = createTree(splitDataSet(dataSet, bestFeature, value), subLabels)
        return myTree
    
    # classify one sample by walking the tree from the root
    def classify(inputTree, featLabels, testVec):
        firstStr = list(inputTree.keys())[0]
        secondDict = inputTree[firstStr]
        # index of this node's feature in the header list
        featIndex = featLabels.index(firstStr)
        for key in secondDict.keys():
            if testVec[featIndex] == key:
                if type(secondDict[key]).__name__ == 'dict':
                    classLabel = classify(secondDict[key], featLabels, testVec)
                else:
                    classLabel = secondDict[key]
        return classLabel

    # classify every sample in a test set
    def classifyAll(inputTree, featLabels, testDataSet):
        classLabelAll = []
        for testVec in testDataSet:
            classLabelAll.append(classify(inputTree, featLabels, testVec))
        return classLabelAll
    
    # serialize the learned tree to disk
    def storeTree(inputTree, filename):
        import pickle
        with open(filename, 'wb') as fw:
            pickle.dump(inputTree, fw)

    # load a previously stored tree
    def grabTree(filename):
        import pickle
        with open(filename, 'rb') as fr:
            return pickle.load(fr)
    
    def main():
        # read the dataset (adjust the path to wherever test.csv lives)
        labels, dataSet = readDataset(file_path=r'D:\test.csv', file_mode='r')
        labels_tmp = labels[:]  # copy: createTree modifies the label list it is given
        decisionTree = createTree(dataSet, labels_tmp)
        storeTree(decisionTree, 'classifierStorage.txt')
        decisionTree = grabTree('classifierStorage.txt')
        treePlotter.createPlot(decisionTree)
        testSet = [['youth', 'high', 'no', 'fair', 'no']]
        print('classifyResult:\n', classifyAll(decisionTree, labels, testSet))

    if __name__ == '__main__':
        main()

    V. Test Results and Conclusion

    (figure of the plotted decision tree, with age at the root, omitted)
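      For reference, the tree ID3 should learn from the 14 samples above can be derived by hand (dict key order may differ from the program's output):

    expectedTree = {'age': {'middle_aged': 'yes',
                            'youth': {'student': {'no': 'no', 'yes': 'yes'}},
                            'senior': {'credit_rating': {'fair': 'yes', 'excellent': 'no'}}}}

      Walking this tree with the test vector ['youth', 'high', 'no', 'fair', 'no'] takes the age=youth branch and then student=no, so main() should print classifyResult: ['no'].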

      From the information-gain formula above, we can see that the information-gain criterion is biased toward attributes with many distinct values. Suppose we treated a sample "ID" column as a candidate splitting attribute. Because every sample's ID is unique, each branch would contain exactly one sample, the conditional entropy would be 0, and the gain would equal the full dataset entropy (the 0.998 figure often quoted comes from a 17-sample example in the references below; for our 14-sample dataset it would be about 0.940). In other words, given a new sample you would only need to know its ID and every other feature would be ignored, so a tree built this way clearly does not generalize. The sketch below shows this numerically.
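      A minimal sketch of that pathology, reusing calcShannonEnt, splitDataSet, and readDataset from decision_tree_ID3.py above (the D:\test.csv path is the same assumption main() makes):

    # demo: a unique "id" column gets the maximum possible information gain
    from decision_tree_ID3 import calcShannonEnt, splitDataSet, readDataset

    labels, dataset = readDataset(r'D:\test.csv', 'r')
    withId = [[str(i)] + row for i, row in enumerate(dataset)]  # prepend unique ids

    base = calcShannonEnt(withId)  # entropy of the whole labeled set, about 0.940
    cond = 0.0
    for value in set(row[0] for row in withId):
        sub = splitDataSet(withId, 0, value)  # each subset holds exactly one sample...
        cond += len(sub) / float(len(withId)) * calcShannonEnt(sub)  # ...so its entropy is 0

    print(base - cond)  # gain == base entropy: 'id' looks "best" yet is useless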

      References:

      http://www.cnblogs.com/wsine/p/5180310.html

      https://zhuanlan.zhihu.com/p/26760551
