zoukankan      html  css  js  c++  java
  • Python编程实验一 决策树实现结果预测

    题目:给定如下训练集和测试集,参考《机器学习》(Tom Mitchell)第三章和《机器学习》(周志华)第四章,先阅读ID3、C4.5CART算法并且仔细阅读附件给出的ID3、C4.5算法python程序,再实现基于基尼指数(Gini index)选择最优划分属性(特征)构造的CART决策树的 python程序。最终提交一份实验报告提交的实验报告给出python实现的完整程序和实验结果。

    训练集:

        outlook     temperature    humidity     windy

        ---------------------------------------------------------

        sunny       hot            high         false         N

        sunny       hot            high         true          N

        overcast    hot            high         false         Y

        rain        mild           high         false         Y

        rain        cool           normal       false         Y

        rain        cool           normal       true          N

        overcast    cool           normal       true          Y

    测试集

        outlook    temperature     humidity     windy

        ---------------------------------------------------------

        sunny       mild           high         false          

        sunny       cool           normal       false         

        rain        mild           normal       false        

        sunny       mild           normal       true          

        overcast    mild           high         true          

        overcast    hot            normal       false         

        true        rain           mild         high        

     

    Python程序

    1、CART.py

    # -*- coding: utf-8 -*-

    ## 参考《机器学习》(Tom M. Mitchell) 第三章 决策树学习

    ## 《机器学习》(周志华), 第四章 决策树

      1 from math import log
      2 import operator
      3 import treePlotter
      4 
      5 def calcGiniIndex(dataSet):
      6     """
      7     输入:数据集
      8     输出:数据集的基尼指数
      9     描述:计算给定数据集的基尼指数
     10     """
     11     numEntries = len(dataSet)                     # 返回数据集的行数
     12     labelCounts = {}                              # 保存每个标签(Label)出现次数的字典
     13     for featVec in dataSet:                       # 对每组特征向量进行统计
     14         currentLabel = featVec[-1]                  # 提取标签信息
     15         if currentLabel not in labelCounts.keys():    # 如果标签没有放入统计次数的字典,添加进去
     16             labelCounts[currentLabel] = 0
     17         labelCounts[currentLabel] += 1            # Label计数
     18     giniIndexEnt = 0.0                            # 基尼指数
     19     for key in labelCounts:                          # 计算基尼指数
     20         prob = float(labelCounts[key])/numEntries       # 选择该标签(Label)的概率
     21         giniIndexEnt += prob * (1.0 - prob)       # 利用公式计算
     22     return giniIndexEnt                              # 返回基尼指数
     23 
     24 def splitDataSet(dataSet, axis, value):       # 待划分数据集合,特征下标,特征值
     25     """
     26     输入:数据集,选择维度,选择值
     27     输出:划分数据集
     28     描述:按照给定特征划分数据集;去除选择维度中等于选择值的项
     29     """
     30     retDataSet = []                          # 保存划分的数据子集
     31     for featVec in dataSet:                  # 遍历数据集中的每个样本
     32         if featVec[axis] == value:           #如果特征值符合要求,则添加到子集中
     33             reduceFeatVec = featVec[:axis]   # 保存第0到第axis-1个特征
     34             reduceFeatVec.extend(featVec[axis+1:])     # 保存第axis+1到最后一个特征
     35             retDataSet.append(reduceFeatVec)           # 添加符合要求的样本到划分子集中
     36     return retDataSet                                  # 返回划分好的(特征axis的值=value)的子集
     37 
     38 def chooseBestFeatureToSplit(dataSet):
     39     """
     40     输入:数据集
     41     输出:最好的划分维度
     42     描述:选择最好的数据集划分维度
     43     """
     44     numFeatures = len(dataSet[0]) - 1                 # 特征数量
     45     bestInfoGini = calcGiniIndex(dataSet)             # 计算数据集的基尼指数
     46     bestFeature = -1                                   # 最优特征索引值
     47     for i in range(numFeatures):                        # 遍历所有特征
     48         featList = [example[i] for example in dataSet]      # 获取dataSet的第i个所有特征-第i列全部特征
     49         uniqueVals = set(featList)                    # 创建set集合{}元素不可重复
     50         newGini = 0.0
     51         for value in uniqueVals:                            # 计算新的基尼指数
     52             subDataSet = splitDataSet(dataSet, i, value)    # subDataSet划分后的子集
     53             prob = len(subDataSet)/float(len(dataSet))      # 计算子集概率
     54             newGini += prob * calcGiniIndex(subDataSet)     # 新的基尼指数
     55         if (newGini < bestInfoGini):
     56             bestInfoGini = newGini
     57             bestFeature = i
     58     return bestFeature                                         # 返回基尼指数最小的特征索引值
     59 
     60 def majorityCnt(classList):
     61     """
     62     输入:分类类别列表
     63     输出:子节点的分类
     64     描述:数据集已经处理了所有属性,但是类标签依然不是唯一的,
     65           采用多数判决的方法决定该子节点的分类
     66     """
     67     classCount = {}
     68     for vote in classList:                # 统计classList中每个元素出现的次数
     69         if vote not in classCount.keys():
     70             classCount[vote] = 0
     71         classCount[vote] += 1
     72     sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reversed=True)  # 根据字典的值降序排序
     73     return sortedClassCount[0][0]            # 返回classList中出现次数最多的元素
     74 
     75 def createTree(dataSet, labels):
     76     """
     77     输入:数据集,特征标签
     78     输出:决策树
     79     描述:递归构建决策树,利用上述的函数
     80     """
     81     classList = [example[-1] for example in dataSet]
     82     if classList.count(classList[0]) == len(classList):
     83         # 类别完全相同,停止划分
     84         return classList[0]
     85     if len(dataSet[0]) == 1:
     86         # 遍历完所有特征时返回出现次数最多的
     87         return majorityCnt(classList)
     88     bestFeat = chooseBestFeatureToSplit(dataSet)
     89     bestFeatLabel = labels[bestFeat]
     90     myTree = {bestFeatLabel:{}}
     91     del(labels[bestFeat])
     92     # 得到列表包括节点所有的属性值
     93     featValues = [example[bestFeat] for example in dataSet]
     94     uniqueVals = set(featValues)
     95     for value in uniqueVals:
     96         subLabels = labels[:]
     97         myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
     98     return myTree
     99 
    100 def classify(inputTree, featLabels, testVec):
    101     """
    102     输入:决策树,分类标签,测试数据
    103     输出:决策结果
    104     描述:跑决策树
    105     """
    106     firstStr = list(inputTree.keys())[0]
    107     secondDict = inputTree[firstStr]
    108     featIndex = featLabels.index(firstStr)
    109     for key in secondDict.keys():
    110         if testVec[featIndex] == key:
    111             if type(secondDict[key]).__name__ == 'dict':
    112                 classLabel = classify(secondDict[key], featLabels, testVec)
    113             else:
    114                 classLabel = secondDict[key]
    115     return classLabel
    116 
    117 def classifyAll(inputTree, featLabels, testDataSet):
    118     """
    119     输入:决策树,分类标签,测试数据集
    120     输出:决策结果
    121     描述:跑决策树
    122     """
    123     classLabelAll = []
    124     for testVec in testDataSet:
    125         classLabelAll.append(classify(inputTree, featLabels, testVec))
    126     return classLabelAll
    127 
    128 def storeTree(inputTree, filename):
    129     """
    130     输入:决策树,保存文件路径
    131     输出:
    132     描述:保存决策树到文件
    133     """
    134     import pickle
    135     fw = open(filename, 'wb')
    136     pickle.dump(inputTree, fw)
    137     fw.close()
    138 
    139 def grabTree(filename):
    140     """
    141     输入:文件路径名
    142     输出:决策树
    143     描述:从文件读取决策树
    144     """
    145     import pickle
    146     fr = open(filename, 'rb')
    147     return pickle.load(fr)
    148 
    149 def createDataSet():
    150     """
    151     outlook->  0: sunny | 1: overcast | 2: rain
    152     temperature-> 0: hot | 1: mild | 2: cool
    153     humidity-> 0: high | 1: normal
    154     windy-> 0: false | 1: true
    155     """
    156     dataSet = [[0, 0, 0, 0, 'N'],
    157                [0, 0, 0, 1, 'N'],
    158                [1, 0, 0, 0, 'Y'],
    159                [2, 1, 0, 0, 'Y'],
    160                [2, 2, 1, 0, 'Y'],
    161                [2, 2, 1, 1, 'N'],
    162                [1, 2, 1, 1, 'Y']]
    163     labels = ['outlook', 'temperature', 'humidity', 'windy']
    164     return dataSet, labels
    165 
    166 def createTestSet():
    167     """
    168     outlook->  0: sunny | 1: overcast | 2: rain
    169     temperature-> 0: hot | 1: mild | 2: cool
    170     humidity-> 0: high | 1: normal
    171     windy-> 0: false | 1: true
    172     """
    173     testSet = [[0, 1, 0, 0],
    174                [0, 2, 1, 0],
    175                [2, 1, 1, 0],
    176                [0, 1, 1, 1],
    177                [1, 1, 0, 1],
    178                [1, 0, 1, 0],
    179                [2, 1, 0, 1]]
    180     return testSet
    181 
    182 def main():
    183     dataSet, labels = createDataSet()
    184     labels_tmp = labels[:] # 拷贝,createTree会改变labels
    185     desicionTree = createTree(dataSet, labels_tmp)
    186     #storeTree(desicionTree, 'classifierStorage.txt')
    187     #desicionTree = grabTree('classifierStorage.txt')
    188     print('desicionTree:
    ', desicionTree)
    189     treePlotter.createPlot(desicionTree)
    190     testSet = createTestSet()
    191     print('classifyResult:
    ', classifyAll(desicionTree, labels, testSet))
    192 
    193 if __name__ == '__main__':
    194     main()
    195 
    196 2、treePlotter.py
    197 import matplotlib.pyplot as plt
    198 
    199 decisionNode = dict(boxstyle="sawtooth", fc="0.8")
    200 leafNode = dict(boxstyle="round4", fc="0.8")
    201 arrow_args = dict(arrowstyle="<-")
    202 
    203 def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    204     """
    205     输入:
    206     输出:
    207     描述:绘制一个点
    208     """
    209     createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', 
    210                             xytext=centerPt, textcoords='axes fraction', 
    211                             va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
    212 
    213 def getNumLeafs(myTree):
    214     """
    215     输入:决策树
    216     输出:决策树的叶子数量
    217     描述:
    218     """
    219     numLeafs = 0
    220     firstStr = list(myTree.keys())[0]
    221     secondDict = myTree[firstStr]
    222     for key in secondDict.keys():
    223         if type(secondDict[key]).__name__ == 'dict':
    224             numLeafs += getNumLeafs(secondDict[key])
    225         else:
    226             numLeafs += 1
    227     return numLeafs
    228 
    229 def getTreeDepth(myTree):
    230     """
    231     输入:决策树
    232     输出:树的深度
    233     描述:
    234     """
    235     maxDepth = 0
    236     firstStr = list(myTree.keys())[0]
    237     secondDict = myTree[firstStr]
    238     for key in secondDict.keys():
    239         if type(secondDict[key]).__name__ == 'dict':
    240             thisDepth = getTreeDepth(secondDict[key]) + 1
    241         else:
    242             thisDepth = 1
    243         if thisDepth > maxDepth:
    244             maxDepth = thisDepth
    245     return maxDepth
    246 
    247 def plotMidText(cntrPt, parentPt, txtString):
    248     """
    249     输入:
    250     输出:
    251     描述:
    252     """
    253     xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    254     yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    255     createPlot.ax1.text(xMid, yMid, txtString)
    256 
    257 def plotTree(myTree, parentPt, nodeTxt):
    258     """
    259     输入:
    260     输出:
    261     描述:
    262     """
    263     numLeafs = getNumLeafs(myTree)
    264     depth = getTreeDepth(myTree)
    265     firstStr = list(myTree.keys())[0]
    266     cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalw, plotTree.yOff)
    267     plotMidText(cntrPt, parentPt, nodeTxt)
    268     plotNode(firstStr, cntrPt, parentPt, decisionNode)
    269     secondDict = myTree[firstStr]
    270     plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    271     for key in secondDict.keys():
    272         if type(secondDict[key]).__name__ == 'dict':
    273             plotTree(secondDict[key], cntrPt, str(key))
    274         else:
    275             plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalw
    276             plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
    277             plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    278     plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
    279 
    280 def createPlot(inTree):
    281     """
    282     输入:决策树
    283     输出:
    284     描述:绘制整个决策树
    285     """
    286     fig = plt.figure(1, facecolor='white')
    287     fig.clf()
    288     axprops = dict(xticks=[], yticks=[])
    289     createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    290     plotTree.totalw = float(getNumLeafs(inTree))
    291     plotTree.totalD = float(getTreeDepth(inTree))
    292     plotTree.xOff = -0.5 / plotTree.totalw
    293     plotTree.yOff = 1.0
    294     plotTree(inTree, (0.5, 1.0), '')
    295     plt.show()

    运行结果:

    desicionTree:

     {'outlook': {0: 'N', 1: 'Y', 2: {'windy': {0: 'Y', 1: 'N'}}}}

    classifyResult:

     ['N', 'N', 'Y', 'N', 'Y', 'Y', 'N']

    Process finished with exit code 0

    画出的决策树图形:

       

  • 相关阅读:
    SQL Server没有足够的内存继续执行程序 (mscorlib)的解决办法
    在IIS上搭建WebSocket服务器(一)
    端口号被占用
    2018年 年度总结
    一个人颓废的九大根源
    Arrays.asList() 踩坑
    电脑关机命令
    div 悬浮
    ajax中 踩过的坑
    oracle 密码过期问题
  • 原文地址:https://www.cnblogs.com/ku1274755259/p/11108940.html
Copyright © 2011-2022 走看看