zoukankan      html  css  js  c++  java
  • 机器学习实战之决策树(二)

      1 # -*- coding: UTF-8 -*-
      2 from math import log
      3 import operator
      4 
      5 """
      6 函数说明:计算给定数据集的经验熵(香农熵)
      7 
      8 Parameters:
      9     dataSet - 数据集
     10 Returns:
     11     shannonEnt - 经验熵(香农熵)
     12 Author:
     13     Jack Cui
     14 Blog:
     15     http://blog.csdn.net/c406495762
     16 Modify:
     17     2017-07-24
     18 """
     19 def calcShannonEnt(dataSet):
     20     numEntires = len(dataSet)                        #返回数据集的行数
     21     labelCounts = {}                                #保存每个标签(Label)出现次数的字典
     22     for featVec in dataSet:                            #对每组特征向量进行统计
     23         currentLabel = featVec[-1]                    #提取标签(Label)信息
     24         if currentLabel not in labelCounts.keys():    #如果标签(Label)没有放入统计次数的字典,添加进去
     25             labelCounts[currentLabel] = 0
     26         labelCounts[currentLabel] += 1                #Label计数
     27     shannonEnt = 0.0                                #经验熵(香农熵)
     28     for key in labelCounts:                            #计算香农熵
     29         prob = float(labelCounts[key]) / numEntires    #选择该标签(Label)的概率
     30         shannonEnt -= prob * log(prob, 2)            #利用公式计算
     31     return shannonEnt                                #返回经验熵(香农熵)
     32 
     33 """
     34 函数说明:创建测试数据集
     35 
     36 Parameters:
     37  38 Returns:
     39     dataSet - 数据集
     40     labels - 特征标签
     41 Author:
     42     Jack Cui
     43 Blog:
     44     http://blog.csdn.net/c406495762
     45 Modify:
     46     2017-07-20
     47 """
     48 def createDataSet():
     49     dataSet = [[0, 0, 0, 0, 'no'],                        #数据集
     50             [0, 0, 0, 1, 'no'],
     51             [0, 1, 0, 1, 'yes'],
     52             [0, 1, 1, 0, 'yes'],
     53             [0, 0, 0, 0, 'no'],
     54             [1, 0, 0, 0, 'no'],
     55             [1, 0, 0, 1, 'no'],
     56             [1, 1, 1, 1, 'yes'],
     57             [1, 0, 1, 2, 'yes'],
     58             [1, 0, 1, 2, 'yes'],
     59             [2, 0, 1, 2, 'yes'],
     60             [2, 0, 1, 1, 'yes'],
     61             [2, 1, 0, 1, 'yes'],
     62             [2, 1, 0, 2, 'yes'],
     63             [2, 0, 0, 0, 'no']]
     64     labels = ['年龄', '有工作', '有自己的房子', '信贷情况']        #特征标签
     65     return dataSet, labels                             #返回数据集和分类属性
     66 
     67 """
     68 函数说明:按照给定特征划分数据集
     69 
     70 Parameters:
     71     dataSet - 待划分的数据集
     72     axis - 划分数据集的特征
     73     value - 需要返回的特征的值
     74 Returns:
     75  76 Author:
     77     Jack Cui
     78 Blog:
     79     http://blog.csdn.net/c406495762
     80 Modify:
     81     2017-07-24
     82 """
     83 #value:0,1,2
     84 #axis:4个特征
     85 """
     86 eg:axis =0,value=1(找出是中年的,把第0列(年龄)去掉)构成的数据再计算标签那一列的信息熵
     87 [ 0, 0, 0, 'no'],
     88 [0, 0, 1, 'no'],
     89 [1, 1, 1, 'yes'],
     90 [0, 1, 2, 'yes'],
     91 [0, 1, 2, 'yes'],
     92 """
     93 def splitDataSet(dataSet, axis, value):
     94     retDataSet = []                                        #创建返回的数据集列表
     95     for featVec in dataSet:                             #遍历数据集
     96         if featVec[axis] == value:
     97             reducedFeatVec = featVec[:axis]                #去掉axis特征
     98             reducedFeatVec.extend(featVec[axis+1:])     #将符合条件的添加到返回的数据集
     99             retDataSet.append(reducedFeatVec)
    100     return retDataSet                                      #返回划分后的数据集
    101 
    102 """
    103 函数说明:选择最优特征
    104 
    105 Parameters:
    106     dataSet - 数据集
    107 Returns:
    108     bestFeature - 信息增益最大的(最优)特征的索引值
    109 Author:
    110     Jack Cui
    111 Blog:
    112     http://blog.csdn.net/c406495762
    113 Modify:
    114     2017-07-20
    115 """
    116 def chooseBestFeatureToSplit(dataSet):
    117     numFeatures = len(dataSet[0]) - 1                    #特征数量为4:len(dataSet[0]):矩阵第一行的长度,最后一列为标签
    118     baseEntropy = calcShannonEnt(dataSet)                 #计算数据集的香农熵
    119     print('香农熵')
    120     print(baseEntropy)
    121     bestInfoGain = 0.0                                  #信息增益
    122     bestFeature = -1                                    #最优特征的索引值
    123     for i in range(numFeatures):                         #遍历所有特征4个特征
    124         #获取dataSet的第i个所有特征
    125         featList = [example[i] for example in dataSet]  #取数据的每一行,再取第i个特征,将第i个特征的值放入一个列表里面
    126         uniqueVals = set(featList)                         #创建set集合{},元素不可重复
    127         newEntropy = 0.0                                  #经验条件熵
    128         for value in uniqueVals:                         #计算信息增益
    129             subDataSet = splitDataSet(dataSet, i, value)         #subDataSet划分后的子集
    130             prob = len(subDataSet) / float(len(dataSet))           #计算子集的概率
    131             newEntropy += prob * calcShannonEnt(subDataSet)     #根据公式计算经验条件熵,只计算子集的信息熵
    132         infoGain = baseEntropy - newEntropy                     #信息增益
    133         # print("第%d个特征的增益为%.3f" % (i, infoGain))            #打印每个特征的信息增益
    134         if (infoGain > bestInfoGain):                             #计算信息增益
    135             bestInfoGain = infoGain                             #更新信息增益,找到最大的信息增益
    136             bestFeature = i                                     #记录信息增益最大的特征的索引值
    137     return bestFeature                                             #返回信息增益最大的特征的索引值
    138 
    139 
    140 """
    141 函数说明:统计classList中出现此处最多的元素(类标签)
    142 
    143 Parameters:
    144     classList - 类标签列表
    145 Returns:
    146     sortedClassCount[0][0] - 出现此处最多的元素(类标签)
    147 Author:
    148     Jack Cui
    149 Blog:
    150     http://blog.csdn.net/c406495762
    151 Modify:
    152     2017-07-24
    153 """
    154 def majorityCnt(classList):
    155     classCount = {}
    156     for vote in classList:                                        #统计classList中每个元素出现的次数
    157         if vote not in classCount.keys():classCount[vote] = 0
    158         classCount[vote] += 1
    159     sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse = True)        #根据字典的值降序排序
    160     return sortedClassCount[0][0]                                #返回classList中出现次数最多的元素
    161 
    162 """
    163 函数说明:创建决策树
    164 
    165 Parameters:
    166     dataSet - 训练数据集
    167     labels - 分类属性标签
    168     featLabels - 存储选择的最优特征标签
    169 Returns:
    170     myTree - 决策树
    171 Author:
    172     Jack Cui
    173 Blog:
    174     http://blog.csdn.net/c406495762
    175 Modify:
    176     2017-07-25
    177 """
    178 def createTree(dataSet, labels, featLabels):
    179     classList = [example[-1] for example in dataSet]            #取分类标签(是否放贷:yes or no)
    180     if classList.count(classList[0]) == len(classList):            #如果类别完全相同则停止继续划分
    181         return classList[0]
    182     if len(dataSet[0]) == 1:                                    #遍历完所有特征时返回出现次数最多的类标签
    183         return majorityCnt(classList)  #分到最后没有数据了,但还有特征,使用投票表决法
    184     bestFeat = chooseBestFeatureToSplit(dataSet)                #选择最优特征
    185     bestFeatLabel = labels[bestFeat]                            #最优特征的标签
    186     featLabels.append(bestFeatLabel)
    187     myTree = {bestFeatLabel:{}}                                    #根据最优特征的标签生成树
    188     del(labels[bestFeat]) #   bestFeat为特征索引                   #删除已经使用特征标签
    189     featValues = [example[bestFeat] for example in dataSet]        #得到训练集中所有最优特征的属性值
    190     #最优特征值那一列有几种情况
    191     uniqueVals = set(featValues)                                #去掉重复的属性值
    192     for value in uniqueVals:                                    #遍历特征,创建决策树。
    193         myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), labels, featLabels)
    194     return myTree
    195 
    196 if __name__ == '__main__':
    197     dataSet, labels = createDataSet()
    198     print('最优索引值'+str(chooseBestFeatureToSplit(dataSet)))
    199     featLabels = []
    200     myTree = createTree(dataSet, labels, featLabels)
    201     print(myTree)

    创建分支(createBranch)伪代码:

    检测数据集中的每个子项是否属于同一类

      if so return 类标签

      Else 

        寻找划分数据集的最好特征

        划分数据集

        创建分支节点

          for  每个划分的子集

            调用函数createBranch并增加返回结果到分支节点中

        return 分支节点

  • 相关阅读:
    2020面向对象程序设计寒假作业2 题解
    题解 P3372 【【模板】线段树 1】
    Global variant VS local variant
    u2u
    深入浅出PowerShell系列
    深入浅出WF系列
    debug
    深入浅出SharePoint系列
    InfoPath debug
    深入浅出Nintex系列
  • 原文地址:https://www.cnblogs.com/shuangcao/p/11374104.html
Copyright © 2011-2022 走看看