zoukankan      html  css  js  c++  java
  • 决策树的构造(代码)

    
    
      1 from matplotlib.font_manager import FontProperties
      2 import matplotlib.pyplot as plt
      3 from math import log
      4 import operator
      5 """
      6 函数说明:创建测试数据集
      7 Parameters:
      8   9 Returns:
     10     dataSet - 数据集
     11     labels - 分类属性
     12 """
     13 def createDataSet():
     14     dataSet = [[0, 0, 0, 0, 'no'],         #数据集
     15             [0, 0, 0, 1, 'no'],
     16             [0, 1, 0, 1, 'yes'],
     17             [0, 1, 1, 0, 'yes'],
     18             [0, 0, 0, 0, 'no'],
     19             [1, 0, 0, 0, 'no'],
     20             [1, 0, 0, 1, 'no'],
     21             [1, 1, 1, 1, 'yes'],
     22             [1, 0, 1, 2, 'yes'],
     23             [1, 0, 1, 2, 'yes'],
     24             [2, 0, 1, 2, 'yes'],
     25             [2, 0, 1, 1, 'yes'],
     26             [2, 1, 0, 1, 'yes'],
     27             [2, 1, 0, 2, 'yes'],
     28             [2, 0, 0, 0, 'no']]
     29     labels = ['年龄', '有工作', '有自己的房子', '信贷情况']        #分类属性
     30     return dataSet, labels                #返回数据集和分类属性
     31 """
     32 函数说明:计算给定数据集的经验熵(香农熵)
     33 Parameters:
     34     dataSet - 数据集
     35 Returns:
     36     shannonEnt - 经验熵(香农熵)
     37 """
     38 def calcShannonEnt(dataSet):
     39     numEntires = len(dataSet)                        #返回数据集的行数
     40     labelCounts = {}                                #保存每个标签(Label)出现次数的字典
     41     for featVec in dataSet:                            #对每组特征向量进行统计
     42         currentLabel = featVec[-1]                    #提取标签(Label)信息
     43         if currentLabel not in labelCounts.keys():    #如果标签(Label)没有放入统计次数的字典,添加进去
     44             labelCounts[currentLabel] = 0
     45         labelCounts[currentLabel] += 1                #Label计数
     46     shannonEnt = 0.0                                #经验熵(香农熵)
     47     for key in labelCounts:                            #计算香农熵
     48         prob = float(labelCounts[key]) / numEntires    #选择该标签(Label)的概率
     49         shannonEnt -= prob * log(prob, 2)            #利用公式计算
     50     return shannonEnt                                #返回经验熵(香农熵)
     51 """
     52 函数说明:创建测试数据集
     53 Parameters:
     54  55 Returns:
     56     dataSet - 数据集
     57     labels - 分类属性
     58 """
     59 def createDataSet():
     60     dataSet = [[0, 0, 0, 0, 'no'],                        #数据集
     61             [0, 0, 0, 1, 'no'],
     62             [0, 1, 0, 1, 'yes'],
     63             [0, 1, 1, 0, 'yes'],
     64             [0, 0, 0, 0, 'no'],
     65             [1, 0, 0, 0, 'no'],
     66             [1, 0, 0, 1, 'no'],
     67             [1, 1, 1, 1, 'yes'],
     68             [1, 0, 1, 2, 'yes'],
     69             [1, 0, 1, 2, 'yes'],
     70             [2, 0, 1, 2, 'yes'],
     71             [2, 0, 1, 1, 'yes'],
     72             [2, 1, 0, 1, 'yes'],
     73             [2, 1, 0, 2, 'yes'],
     74             [2, 0, 0, 0, 'no']]
     75     labels = ['年龄', '有工作', '有自己的房子', '信贷情况']        #分类属性
     76     return dataSet, labels                             #返回数据集和分类属性
     77 """
     78 函数说明:按照给定特征划分数据集
     79 Parameters:
     80     dataSet - 待划分的数据集
     81     axis - 划分数据集的特征
     82     value - 需要返回的特征的值
     83 Returns:
     84  85 """
     86 def splitDataSet(dataSet, axis, value):
     87     retDataSet = []                                        #创建返回的数据集列表
     88     for featVec in dataSet:                             #遍历数据集
     89         if featVec[axis] == value:
     90             reducedFeatVec = featVec[:axis]                #去掉axis特征
     91             reducedFeatVec.extend(featVec[axis+1:])     #将符合条件的添加到返回的数据集
     92             retDataSet.append(reducedFeatVec)
     93     return retDataSet                                      #返回划分后的数据集
     94 """
     95 函数说明:选择最优特征
     96 Parameters:
     97     dataSet - 数据集
     98 Returns:
     99     bestFeature - 信息增益最大的(最优)特征的索引值
    100 """
    101 def chooseBestFeatureToSplit(dataSet):
    102     numFeatures = len(dataSet[0]) - 1                    #特征数量
    103     baseEntropy = calcShannonEnt(dataSet)                 #计算数据集的香农熵
    104     bestInfoGain = 0.0                                  #信息增益
    105     bestFeature = -1                                    #最优特征的索引值
    106     for i in range(numFeatures):                         #遍历所有特征
    107         #获取dataSet的第i个所有特征
    108         featList = [example[i] for example in dataSet]
    109         uniqueVals = set(featList)                         #创建set集合{},元素不可重复
    110         newEntropy = 0.0                                  #经验条件熵
    111         for value in uniqueVals:                         #计算信息增益
    112             subDataSet = splitDataSet(dataSet, i, value)         #subDataSet划分后的子集
    113             prob = len(subDataSet) / float(len(dataSet))           #计算子集的概率
    114             newEntropy += prob * calcShannonEnt(subDataSet)     #根据公式计算经验条件熵
    115         infoGain = baseEntropy - newEntropy                     #信息增益
    116         print("第%d个特征的增益为%.3f" % (i, infoGain))            #打印每个特征的信息增益
    117         if (infoGain > bestInfoGain):                             #计算信息增益
    118             bestInfoGain = infoGain                             #更新信息增益,找到最大的信息增益
    119             bestFeature = i                                     #记录信息增益最大的特征的索引值
    120     return bestFeature                                             #返回信息增益最大的特征的索引值
    121 
    122 # if __name__ == '__main__':
    123 #     dataSet, features = createDataSet()
    124 #     print("最优特征索引值:" + str(chooseBestFeatureToSplit(dataSet)))
    125 
    126 # if __name__ == '__main__':
    127 #     dataSet, features = createDataSet()
    128 #     print(dataSet)
    129 #     print(calcShannonEnt(dataSet))
    130 
    131 #递归构建决策树
    132 """
    133 函数说明:统计classList中出现此处最多的元素(类标签)
    134 Parameters:
    135     classList - 类标签列表
    136 Returns:
    137     sortedClassCount[0][0] - 出现此处最多的元素(类标签)
    138 """
    139 def majorityCnt(classList):
    140     classCount = {}
    141     for vote in classList:                                        #统计classList中每个元素出现的次数
    142         if vote not in classCount.keys():classCount[vote] = 0
    143         classCount[vote] += 1
    144     sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse = True)        #根据字典的值降序排序
    145     return sortedClassCount[0][0]                                #返回classList中出现次数最多的元素
    146 """
    147 函数说明:创建决策树
    148 Parameters:
    149     dataSet - 训练数据集
    150     labels - 分类属性标签
    151     featLabels - 存储选择的最优特征标签
    152 Returns:
    153     myTree - 决策树
    154 """
    155 def createTree(dataSet, labels, featLabels):
    156     classList = [example[-1] for example in dataSet]            #取分类标签(是否放贷:yes or no)
    157     if classList.count(classList[0]) == len(classList):            #如果类别完全相同则停止继续划分
    158         return classList[0]
    159     if len(dataSet[0]) == 1:                                    #遍历完所有特征时返回出现次数最多的类标签
    160         return majorityCnt(classList)
    161     bestFeat = chooseBestFeatureToSplit(dataSet)                #选择最优特征
    162     bestFeatLabel = labels[bestFeat]                            #最优特征的标签
    163     featLabels.append(bestFeatLabel)
    164     myTree = {bestFeatLabel:{}}                                    #根据最优特征的标签生成树
    165     del(labels[bestFeat])                                        #删除已经使用特征标签
    166     featValues = [example[bestFeat] for example in dataSet]        #得到训练集中所有最优特征的属性值
    167     uniqueVals = set(featValues)                                #去掉重复的属性值
    168     for value in uniqueVals:                                    #遍历特征,创建决策树。
    169         myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), labels, featLabels)
    170     return myTree
    171 
    172 # if __name__ == '__main__':
    173 #     dataSet, labels = createDataSet()
    174 #     featLabels = []
    175 #     myTree = createTree(dataSet, labels, featLabels)
    176 #     print(myTree)
    177 
    178 #决策树可视化
    179 """
    180 函数说明:获取决策树叶子结点的数目
    181 Parameters:
    182     myTree - 决策树
    183 Returns:
    184     numLeafs - 决策树的叶子结点的数目
    185 """
    186 def getNumLeafs(myTree):
    187     numLeafs = 0                                                #初始化叶子
    188     firstStr = next(iter(myTree))                                #python3中myTree.keys()返回的是dict_keys,不在是list,所以不能使用myTree.keys()[0]的方法获取结点属性,可以使用list(myTree.keys())[0]
    189     secondDict = myTree[firstStr]                                #获取下一组字典
    190     for key in secondDict.keys():
    191         if type(secondDict[key]).__name__=='dict':                #测试该结点是否为字典,如果不是字典,代表此结点为叶子结点
    192             numLeafs += getNumLeafs(secondDict[key])
    193         else:   numLeafs +=1
    194     return numLeafs
    195 """
    196 函数说明:获取决策树的层数
    197 Parameters:
    198     myTree - 决策树
    199 Returns:
    200     maxDepth - 决策树的层数
    201 """
    202 def getTreeDepth(myTree):
    203     maxDepth = 0                                                #初始化决策树深度
    204     firstStr = next(iter(myTree))                                #python3中myTree.keys()返回的是dict_keys,不在是list,所以不能使用myTree.keys()[0]的方法获取结点属性,可以使用list(myTree.keys())[0]
    205     secondDict = myTree[firstStr]                                #获取下一个字典
    206     for key in secondDict.keys():
    207         if type(secondDict[key]).__name__=='dict':                #测试该结点是否为字典,如果不是字典,代表此结点为叶子结点
    208             thisDepth = 1 + getTreeDepth(secondDict[key])
    209         else:   thisDepth = 1
    210         if thisDepth > maxDepth: maxDepth = thisDepth            #更新层数
    211     return maxDepth
    212 """
    213 函数说明:绘制结点
    214 Parameters:
    215     nodeTxt - 结点名
    216     centerPt - 文本位置
    217     parentPt - 标注的箭头位置
    218     nodeType - 结点格式
    219 Returns:
    220 221 """
    222 def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    223     arrow_args = dict(arrowstyle="<-")                                            #定义箭头格式
    224     font = FontProperties(fname=r"c:windowsfontssimsun.ttc", size=14)        #设置中文字体
    225     createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',    #绘制结点
    226         xytext=centerPt, textcoords='axes fraction',
    227         va="center", ha="center", bbox=nodeType, arrowprops=arrow_args, FontProperties=font)
    228 """
    229 函数说明:标注有向边属性值
    230 Parameters:
    231     cntrPt、parentPt - 用于计算标注位置
    232     txtString - 标注的内容
    233 Returns:
    234 235 """
    236 def plotMidText(cntrPt, parentPt, txtString):
    237     xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]                                            #计算标注位置
    238     yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    239     createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
    240 """
    241 函数说明:绘制决策树
    242 Parameters:
    243     myTree - 决策树(字典)
    244     parentPt - 标注的内容
    245     nodeTxt - 结点名
    246 Returns:
    247 248 """
    249 def plotTree(myTree, parentPt, nodeTxt):
    250     decisionNode = dict(boxstyle="sawtooth", fc="0.8")                                        #设置结点格式
    251     leafNode = dict(boxstyle="round4", fc="0.8")                                            #设置叶结点格式
    252     numLeafs = getNumLeafs(myTree)                                                          #获取决策树叶结点数目,决定了树的宽度
    253     depth = getTreeDepth(myTree)                                                            #获取决策树层数
    254     firstStr = next(iter(myTree))                                                            #下个字典
    255     cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)    #中心位置
    256     plotMidText(cntrPt, parentPt, nodeTxt)                                                    #标注有向边属性值
    257     plotNode(firstStr, cntrPt, parentPt, decisionNode)                                        #绘制结点
    258     secondDict = myTree[firstStr]                                                            #下一个字典,也就是继续绘制子结点
    259     plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD                                        #y偏移
    260     for key in secondDict.keys():
    261         if type(secondDict[key]).__name__=='dict':                                            #测试该结点是否为字典,如果不是字典,代表此结点为叶子结点
    262             plotTree(secondDict[key],cntrPt,str(key))                                        #不是叶结点,递归调用继续绘制
    263         else:                                                                                #如果是叶结点,绘制叶结点,并标注有向边属性值
    264             plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
    265             plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
    266             plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    267     plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
    268 
    269 """
    270 函数说明:创建绘制面板
    271 Parameters:
    272     inTree - 决策树(字典)
    273 Returns:
    274 275 """
    276 def createPlot(inTree):
    277     fig = plt.figure(1, facecolor='white')                                                    #创建fig
    278     fig.clf()                                                                                #清空fig
    279     axprops = dict(xticks=[], yticks=[])
    280     createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)                                #去掉x、y轴
    281     plotTree.totalW = float(getNumLeafs(inTree))                                            #获取决策树叶结点数目
    282     plotTree.totalD = float(getTreeDepth(inTree))                                            #获取决策树层数
    283     plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;                                #x偏移
    284     plotTree(inTree, (0.5,1.0), '')                                                            #绘制决策树
    285     plt.show()                                                                                 #显示绘制结果
    286 
    287 # if __name__ == '__main__':
    288 #     dataSet, labels = createDataSet()
    289 #     featLabels = []
    290 #     myTree = createTree(dataSet, labels, featLabels)
    291 #     print(myTree)
    292 #     createPlot(myTree)
    293 
    294 #使用决策树进行分类
    295 """
    296 函数说明:使用决策树分类
    297 Parameters:
    298     inputTree - 已经生成的决策树
    299     featLabels - 存储选择的最优特征标签
    300     testVec - 测试数据列表,顺序对应最优特征标签
    301 Returns:
    302     classLabel - 分类结果
    303 """
    304 def classify(inputTree, featLabels, testVec):
    305     firstStr = next(iter(inputTree))                                                        #获取决策树结点
    306     secondDict = inputTree[firstStr]                                                        #下一个字典
    307     featIndex = featLabels.index(firstStr)
    308     for key in secondDict.keys():
    309         if testVec[featIndex] == key:
    310             if type(secondDict[key]).__name__ == 'dict':
    311                 classLabel = classify(secondDict[key], featLabels, testVec)
    312             else: classLabel = secondDict[key]
    313     return classLabel
    314 
    315 if __name__ == '__main__':
    316     dataSet, labels = createDataSet()
    317     featLabels = []
    318     myTree = createTree(dataSet, labels, featLabels)
    319     testVec = [0,1]                                        #测试数据
    320     result = classify(myTree, featLabels, testVec)
    321     if result == 'yes':
    322         print('放贷')
    323     if result == 'no':
    324         print('不放贷')
    325 
    326 #决策树的存储
    327 import pickle
    328 """
    329 函数说明:存储决策树
    330 Parameters:
    331     inputTree - 已经生成的决策树
    332     filename - 决策树的存储文件名
    333 Returns:
    334 335 """
    336 def storeTree(inputTree, filename):
    337     with open(filename, 'wb') as fw:
    338         pickle.dump(inputTree, fw)
    339 
    340 if __name__ == '__main__':
    341     myTree = {'有自己的房子': {0: {'有工作': {0: 'no', 1: 'yes'}}, 1: 'yes'}}
    342     storeTree(myTree, 'classifierStorage.txt')
    343 """
    344 函数说明:读取决策树
    345 Parameters:
    346     filename - 决策树的存储文件名
    347 Returns:
    348     pickle.load(fr) - 决策树字典
    349 """
    350 def grabTree(filename):
    351     fr = open(filename, 'rb')
    352     return pickle.load(fr)
    353 
    354 if __name__ == '__main__':
    355     myTree = grabTree('classifierStorage.txt')
    356     print(myTree)
    357 if __name__ == '__main__':
    358     fr = open('lenses.txt')
    359     lenses = [inst.strip().split('	') for inst in fr.readlines()]
    360     print(lenses)
    361     lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    362     myTree_lenses = createTree(lenses, lensesLabels)
    363     createPlot(myTree_lenses)
    
    

    另一个版本代码

      1 from matplotlib.font_manager import FontProperties
      2 import matplotlib.pyplot as plt
      3 from math import log
      4 import operator
      5 import pickle
      6 """
      7 函数说明:计算给定数据集的经验熵(香农熵)
      8 Parameters:
      9     dataSet - 数据集
     10 Returns:
     11     shannonEnt - 经验熵(香农熵)
     12     
     13 函数说明:创建测试数据集
     14 Parameters:
     15  16 Returns:
     17     dataSet - 数据集
     18     labels - 分类属性
     19 """
     20 def createDataSet():
     21     dataSet = [[0, 0, 0, 0, 'no'],                        #数据集
     22                [0, 0, 0, 1, 'no'],
     23                [0, 1, 0, 1, 'yes'],
     24                [0, 1, 1, 0, 'yes'],
     25                [0, 0, 0, 0, 'no'],
     26                [1, 0, 0, 0, 'no'],
     27                [1, 0, 0, 1, 'no'],
     28                [1, 1, 1, 1, 'yes'],
     29                [1, 0, 1, 2, 'yes'],
     30                [1, 0, 1, 2, 'yes'],
     31                [2, 0, 1, 2, 'yes'],
     32                [2, 0, 1, 1, 'yes'],
     33                [2, 1, 0, 1, 'yes'],
     34                [2, 1, 0, 2, 'yes'],
     35                [2, 0, 0, 0, 'no']]
     36     labels = ['年龄', '有工作', '有自己的房子', '信贷情况']        #分类属性
     37     return dataSet, labels                             #返回数据集和分类属性
     38 
     39 '''
     40 函数说明:按照给定特征划分数据集
     41 Parameters:
     42     dataSet - 待划分的数据集
     43     axis - 划分数据集的特征
     44     value - 需要返回的特征的值
     45 '''
     46 def calcShannonEnt(dataSet):
     47     numEntires = len(dataSet)                        #返回数据集的行数
     48     labelCounts = {}                                #保存每个标签(Label)出现次数的字典
     49     for featVec in dataSet:                            #对每组特征向量进行统计
     50         currentLabel = featVec[-1]                    #提取标签(Label)信息
     51         if currentLabel not in labelCounts.keys():    #如果标签(Label)没有放入统计次数的字典,添加进去
     52             labelCounts[currentLabel] = 0
     53         labelCounts[currentLabel] += 1                #Label计数
     54     shannonEnt = 0.0                                #经验熵(香农熵)
     55     for key in labelCounts:                            #计算香农熵
     56         prob = float(labelCounts[key]) / numEntires    #选择该标签(Label)的概率
     57         shannonEnt -= prob * log(prob, 2)            #利用公式计算
     58     return shannonEnt                                #返回经验熵(香农熵)
     59 
     60 if __name__ == '__main__':
     61     dataSet, features = createDataSet()
     62     print(dataSet)
     63     print(calcShannonEnt(dataSet))
     64 
     65 """
     66 函数说明:按照给定特征划分数据集
     67 Parameters:
     68     dataSet - 待划分的数据集
     69     axis - 划分数据集的特征
     70     value - 需要返回的特征的值
     71 """
     72 def splitDataSet(dataSet, axis, value):
     73     retDataSet = []                                        #创建返回的数据集列表
     74     for featVec in dataSet:                             #遍历数据集
     75         if featVec[axis] == value:
     76             reducedFeatVec = featVec[:axis]                #去掉axis特征
     77             reducedFeatVec.extend(featVec[axis+1:])     #将符合条件的添加到返回的数据集
     78             retDataSet.append(reducedFeatVec)
     79     return retDataSet                                      #返回划分后的数据集
     80 
     81 """
     82 函数说明:选择最优特征
     83 Parameters:
     84     dataSet - 数据集
     85 Returns:
     86     bestFeature - 信息增益最大的(最优)特征的索引值
     87 """
     88 def chooseBestFeatureToSplit(dataSet):
     89     numFeatures = len(dataSet[0]) - 1                    #特征数量
     90     baseEntropy = calcShannonEnt(dataSet)                 #计算数据集的香农熵
     91     bestInfoGain = 0.0                                  #信息增益
     92     bestFeature = -1                                    #最优特征的索引值
     93     for i in range(numFeatures):                         #遍历所有特征
     94         #获取dataSet的第i个所有特征
     95         featList = [example[i] for example in dataSet]
     96         uniqueVals = set(featList)                         #创建set集合{},元素不可重复
     97         newEntropy = 0.0                                  #经验条件熵
     98         for value in uniqueVals:                         #计算信息增益
     99             subDataSet = splitDataSet(dataSet, i, value)         #subDataSet划分后的子集
    100             prob = len(subDataSet) / float(len(dataSet))           #计算子集的概率
    101             newEntropy += prob * calcShannonEnt(subDataSet)     #根据公式计算经验条件熵
    102         infoGain = baseEntropy - newEntropy                     #信息增益
    103         print("第%d个特征的增益为%.3f" % (i, infoGain))            #打印每个特征的信息增益
    104         if (infoGain > bestInfoGain):                             #计算信息增益
    105             bestInfoGain = infoGain                             #更新信息增益,找到最大的信息增益
    106             bestFeature = i                                     #记录信息增益最大的特征的索引值
    107     return bestFeature                                             #返回信息增益最大的特征的索引值
    108 
    109 '''
    110 函数说明:统计classList中出现此处最多的元素(类标签)
    111 
    112 Parameters:
    113     classList - 类标签列表
    114 Returns:
    115     sortedClassCount[0][0] - 出现此处最多的元素(类标签)
    116 '''
    117 def majorityCnt(classList):
    118     classCount = {}
    119     for vote in classList:           #统计classList中每个元素出现的次数
    120         if vote not in classCount.keys():
    121             classCount[vote] = 0
    122         classCount[vote] += 1
    123     sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)        #根据字典的值降序排序
    124     return sortedClassCount[0][0]    #返回classList中出现次数最多的元素
    125 
    126 '''
    127 函数说明:创建决策树
    128 Parameters:
    129     dataSet - 训练数据集
    130     labels - 分类属性标签
    131     featLabels - 存储选择的最优特征标签
    132 Returns:
    133     myTree - 决策树
    134 '''
    135 def createTree(dataSet, labels, featLabels):
    136     classList = [example[-1] for example in dataSet]   #取分类标签(是否放贷:yes or no)
    137     if classList.count(classList[0]) == len(classList):  #如果类别完全相同则停止继续划分
    138         return classList[0]
    139     if len(dataSet[0]) == 1:             #遍历完所有特征时返回出现次数最多的类标签
    140         return majorityCnt(classList)
    141     bestFeat = chooseBestFeatureToSplit(dataSet)      #选择最优特征
    142     bestFeatLabel = labels[bestFeat]                 #最优特征的标签
    143     featLabels.append(bestFeatLabel)
    144     myTree = {bestFeatLabel: {}}                    #根据最优特征的标签生成树
    145     del(labels[bestFeat])                           #删除已经使用特征标签
    146     featValues = [example[bestFeat] for example in dataSet]   #得到训练集中所有最优特征的属性值
    147     uniqueVals = set(featValues)       #去掉重复的属性值
    148     for value in uniqueVals:           #遍历特征,创建决策树。
    149         myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), labels, featLabels)
    150     return myTree
    151 
    152 '''
    153 getNumLeafs:获取决策树叶子结点的数目
    154     getTreeDepth:获取决策树的层数
    155     plotNode:绘制结点
    156     plotMidText:标注有向边属性值
    157     plotTree:绘制决策树
    158     createPlot:创建绘制面板
    159 
    160 函数说明:获取决策树叶子结点的数目
    161 Parameters:
    162     myTree - 决策树
    163 Returns:
    164     numLeafs - 决策树的叶子结点的数目
    165 '''
    166 def getNumLeafs(myTree):
    167     numLeafs = 0                   #初始化叶子
    168     firstStr = next(iter(myTree))    #python3中myTree.keys()返回的是dict_keys,不在是list,所以不能使用myTree.keys()[0]的方法获取结点属性,可以使用list(myTree.keys())[0]
    169     secondDict = myTree[firstStr]      #获取下一组字典
    170     for key in secondDict.keys():
    171         if type(secondDict[key]).__name__=='dict':      #测试该结点是否为字典,如果不是字典,代表此结点为叶子结点
    172             numLeafs += getNumLeafs(secondDict[key])
    173         else:   numLeafs +=1
    174     return numLeafs
    175 
    176 '''
    177 函数说明:获取决策树的层数
    178 Parameters:
    179     myTree - 决策树
    180 Returns:
    181     maxDepth - 决策树的层数
    182 '''
    183 def getTreeDepth(myTree):
    184     maxDepth = 0                       #初始化决策树深度
    185     firstStr = next(iter(myTree))      #python3中myTree.keys()返回的是dict_keys,不在是list,所以不能使用myTree.keys()[0]的方法获取结点属性,可以使用list(myTree.keys())[0]
    186     secondDict = myTree[firstStr]      #获取下一个字典
    187     for key in secondDict.keys():
    188         if type(secondDict[key]).__name__=='dict':   #测试该结点是否为字典,如果不是字典,代表此结点为叶子结点
    189             thisDepth = 1 + getTreeDepth(secondDict[key])
    190         else:   thisDepth = 1
    191         if thisDepth > maxDepth: maxDepth = thisDepth            #更新层数
    192     return maxDepth
    193 
    194 '''
    195 函数说明:绘制结点
    196 Parameters:
    197     nodeTxt - 结点名
    198     centerPt - 文本位置
    199     parentPt - 标注的箭头位置
    200     nodeType - 结点格式
    201 Returns:
    202 203 '''
    204 def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    205     arrow_args = dict(arrowstyle="<-")       #定义箭头格式
    206     font = FontProperties(fname=r"c:windowsfontssimsun.ttc", size=14)     #设置中文字体
    207     createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',    #绘制结点
    208         xytext=centerPt, textcoords='axes fraction',
    209         va="center", ha="center", bbox=nodeType, arrowprops=arrow_args, FontProperties=font)
    210 
    211 '''
    212 函数说明:标注有向边属性值
    213 Parameters:
    214     cntrPt、parentPt - 用于计算标注位置
    215     txtString - 标注的内容
    216 Returns:
    217 218 '''
    219 def plotMidText(cntrPt, parentPt, txtString):
    220     xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]                                            #计算标注位置
    221     yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    222     createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
    223 
    224 '''
    225 函数说明:绘制决策树
    226 Parameters:
    227     myTree - 决策树(字典)
    228     parentPt - 标注的内容
    229     nodeTxt - 结点名
    230 Returns:
    231 232 '''
    233 def plotTree(myTree, parentPt, nodeTxt):
    234     decisionNode = dict(boxstyle="sawtooth", fc="0.8")                                        #设置结点格式
    235     leafNode = dict(boxstyle="round4", fc="0.8")                                            #设置叶结点格式
    236     numLeafs = getNumLeafs(myTree)                                                          #获取决策树叶结点数目,决定了树的宽度
    237     depth = getTreeDepth(myTree)                                                            #获取决策树层数
    238     firstStr = next(iter(myTree))                                                            #下个字典
    239     cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)    #中心位置
    240     plotMidText(cntrPt, parentPt, nodeTxt)                                                    #标注有向边属性值
    241     plotNode(firstStr, cntrPt, parentPt, decisionNode)                                        #绘制结点
    242     secondDict = myTree[firstStr]                                                            #下一个字典,也就是继续绘制子结点
    243     plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD                                        #y偏移
    244     for key in secondDict.keys():
    245         if type(secondDict[key]).__name__=='dict':                                            #测试该结点是否为字典,如果不是字典,代表此结点为叶子结点
    246             plotTree(secondDict[key],cntrPt,str(key))                                        #不是叶结点,递归调用继续绘制
    247         else:                                                                                #如果是叶结点,绘制叶结点,并标注有向边属性值
    248             plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
    249             plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
    250             plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    251     plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
    252 
    253 '''
    254 函数说明:创建绘制面板
    255 Parameters:
    256     inTree - 决策树(字典)
    257 Returns:
    258 259 '''
    260 def createPlot(inTree):
    261     fig = plt.figure(1, facecolor='white')                                                    #创建fig
    262     fig.clf()                                                                                #清空fig
    263     axprops = dict(xticks=[], yticks=[])
    264     createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)                                #去掉x、y轴
    265     plotTree.totalW = float(getNumLeafs(inTree))                                            #获取决策树叶结点数目
    266     plotTree.totalD = float(getTreeDepth(inTree))                                            #获取决策树层数
    267     plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;                                #x偏移
    268     plotTree(inTree, (0.5,1.0), '')                                                            #绘制决策树
    269     plt.show()                                                                                 #显示绘制结果
    270 
    271 def classify(inputTree, featLabels, testVec):
    272     firstStr = next(iter(inputTree))                                                        #获取决策树结点
    273     secondDict = inputTree[firstStr]                                                        #下一个字典
    274     featIndex = featLabels.index(firstStr)
    275     for key in secondDict.keys():
    276         if testVec[featIndex] == key:
    277             if type(secondDict[key]).__name__ == 'dict':
    278                 classLabel = classify(secondDict[key], featLabels, testVec)
    279             else: classLabel = secondDict[key]
    280     return classLabel
    281 
    282 if __name__ == '__main__':
    283     dataSet, labels = createDataSet()
    284     featLabels = []
    285     myTree = createTree(dataSet, labels, featLabels)
    286     testVec = [0,1]                                        #测试数据
    287     result = classify(myTree, featLabels, testVec)
    288     if result == 'yes':
    289         print('放贷')
    290     if result == 'no':
    291         print('不放贷')
    292 '''
    293 函数说明: 存储决策树
    294 Parameters:
    295     inputTree - 已经生成的决策树
    296     filename - 决策树的存储文件名
    297     Returns:
    298 299 '''
    300 def storeTree(inputTree, filename):
    301         fw = open(filename, 'wb')
    302         pickle.dump(inputTree, fw)
    303         fw.close()
    304 
    305 if __name__ == '__main__':
    306     myTree = {'有自己的房子': {0: {'有工作': {0: 'no', 1: 'yes'}}, 1: 'yes'}}
    307     storeTree(myTree, 'classifierStorage.txt')
    308 
    309 '''
    310 函数说明:读取决策树
    311 Parameters:
    312     filename - 决策树的存储文件名
    313 Returns:
    314     pickle.load(fr) - 决策树字典
    315 '''
    316 def grabTree(filename):
    317     fr = open(filename, 'rb')
    318     return pickle.load(fr)
    319 
    320 if __name__ == '__main__':
    321     myTree = grabTree('classifierStorage.txt')
    322     print(myTree)
    323 
    324 # if __name__ == '__main__':
    325 #     dataSet, labels = createDataSet()
    326 #     featLabels = []
    327 #     myTree = createTree(dataSet, labels, featLabels)
    328 #     print(myTree)
    329 #     createPlot(myTree)
    330 
    331 
    332 # if __name__ == '__main__':
    333 #     dataSet, features = createDataSet()
    334 #     print("最优特征索引值:" + str(chooseBestFeatureToSplit(dataSet)))
  • 相关阅读:
    学生成绩判定系统
    A@2a139a55 结果产生的原因
    为什么子类的构造方法在运行之前,必须调用父类的构造方法?能不能反过来?为什么不能反过来?
    父类与子类之间构造方法的调用关系
    阅读《大道至简》第六章有感
    大数加法 待完善
    BigInteger大数加法源代码及分析
    随机数组求和
    读《大道至简》第五章“失败的过程也是过程 ”有感
    学习进度第15周
  • 原文地址:https://www.cnblogs.com/fd-682012/p/11593724.html
Copyright © 2011-2022 走看看