zoukankan      html  css  js  c++  java
  • machine learning in action --decision tree

    一、知道如何计算熵

       wiki对熵的解释以及熵的计算

     https://zh.wikipedia.org/wiki/%E7%86%B5_(%E4%BF%A1%E6%81%AF%E8%AE%BA)

    综合来说,熵实际是对随机变量的比特量和顺次发生概率相乘再总和的数学期望

    原理:

           在构造决策树时, 我们需要解决的第一个问题就是, 当前数据集上哪个特征在划分数据分类
    时起决定性作用。为了找到决定性的特征,划分出最好的结果,我们必须评估每个特征。完成测
    试之后, 原始数据集就被划分为几个数据子集。这些数据子集会分布在第一个决策点的所有分支

    上。 如果某个分支下的数据属于同一类型, 则当前无需阅读的垃圾邮件已经正确地划分数据分类,
    无需进一步对数据集进行分割。如果数据子集内的数据不属于同一类型, 则需要重复划分数据子
    集的过程。如何划分数据子集的算法和划分原始数据集的方法相同, 直到所有具有相同类型的数
    据均在一个数据子集内。

    二、决策树代码解析

    1.计算熵的函数

     1 def calcShannonEnt(dataSet):
     2     numEntries = len(dataSet)                       #[[1,2,3],[4,5,6],[7,8,9]] 的len是3
     3     labelCounts = {}                                #创建一个labelCounts字典
     4     for featVec in dataSet:
     5         currentLabel = featVec[-1]
     6         if currentLabel not in labelCounts.keys():
     7             labelCounts[currentLabel] = 0
     8         labelCounts[currentLabel] += 1             #for 最后计算出了数据集中所有labels(dataset中的最后一列)出现的次数
     9     shannonEnt = 0.0
    10     for key in labelCounts:                        #参考熵的计算公式
    11         prob = float(labelCounts[key])/numEntries
    12         shannonEnt -= prob * log(prob,2)
    13     return shannonEnt

    2.创建数据集,这个很简单没什么好解释

    1 def createDataSet():
    2     dataSet = [[1, 1, 'yes'],
    3                [1, 1, 'yes'],
    4                [1, 0, 'no'],
    5                [0, 1, 'no'],
    6                [0, 1, 'no']]
    7     labels = ['no surfacing','flippers']
    8     return dataSet, labels

    3.分割数据集

    1 def splitDataSet(dataSet, axis, value):                   #调用的时候:splitDataSet(dataSet,0,1)表示按索引为0,其value==1的特征来分割
    2     retDataSet = []
    3     for featVec in dataSet:                               #遍历dataSet中的每一行,如果这一行axis列的值确实等于value,则把这一行除axis列之外的数据都放在
    4         if featVec[axis] == value:                        #reducedFeatVec中,再append到retDataSet
    5             reducedFeatVec = featVec[:axis]
    6             reducedFeatVec.extend(featVec[axis+1:])
    7             retDataSet.append(reducedFeatVec)
    8     return retDataSet

    因此,如果调用splitDataSet(dataSet,0,1)则返回新的数据集为

    [[1, 'yes'],
     [ 1, 'yes'],
     [ 0, 'no']]

    4.选择最好的分割方式

    熵计算将会告诉我们如何划分数据集是最好的数据组织方式,原则就是熵越大越好

     1 def chooseBestFeatureToSplit(dataSet):
     2     numFeatures = len(dataSet[0]) - 1                          #算出一共有多少个特征
     3     baseEntropy = calcShannonEnt(dataSet)                      #计算整个数据集的香农熵
     4     bestInfoGain = 0.0; bestFeature = -1
     5     for i in range(numFeatures):
     6         featList = [example[i] for example in dataSet]         #表示取dataSet中的第i列数据
     7         uniqueVals = set(featList)                             #取这一列的唯一标签[1,1,1,0,0] ----->[0,1]
     8         newEntropy = 0.0
     9         for value in uniqueVals:
    10             subDataSet = splitDataSet(dataSet, i, value)       #调到了分割函数,注意i和value分别是什么
    11             prob = len(subDataSet)/float(len(dataSet))
    12             newEntropy += prob * calcShannonEnt(subDataSet)
    13         infoGain = baseEntropy - newEntropy
    14         if (infoGain > bestInfoGain):
    15             bestInfoGain = infoGain
    16             bestFeature = i
    17     return bestFeature

    代码13到16行,一眼看过去,是在选择划分数据集最好的特征,但是一开始没理解为什么要用数据集原始熵减去划分后的子数集的熵,且这个结果越大说明划分越好,

    后来看书里提到,熵可以理解为表示数据的无序程度,数据越无序,则熵越大,因此,划分后的子数据集的熵越小越小,越小说明这个子数据集里的数据无序度越小,反过来就是说

    里面的数据更大程度上属于一类

    二、创建树

    1.选举代码 ,选出分类list里出现频率最高的

    1     def majorityCnt(classList):
    2         classCount={}
    3         for vote in classList:
    4             if vote not in classCount.keys():classCount[vote]=0
    5             classCount[vote] +=1
    6         sortedClassCount = sorted(classCount.iteritems(),
    7           key=operator.itemgetter(1), reverse=True)
    8         return sortedClassCount[0][0]

    2.创建树的函数

     1 def createTree(dataSet,labels):                           #注意这是一个递归函数,labels可理解为每一列特征的标签,或每一列特征的名字
     2     classList = [example[-1] for example in dataSet]
     3     if classList.count(classList[0]) == len(classList):   #如果当前数据集里的数据都属一类,则不必再划分,注意返回值是当前的标签类别
     4         return classList[0]
     5     if len(dataSet[0]) == 1:                              #这表示数据集只剩下最后一列即标签列,特征列都在被分割的过程中消耗光了,这时就要用选举法来决定类别
     6         return majorityCnt(classList)
     7     bestFeat = chooseBestFeatureToSplit(dataSet)          #选取最佳的特征值来进行分割
     8     bestFeatLabel = labels[bestFeat]
     9     myTree = {bestFeatLabel:{}}                           #开始构建myTree,可以看到myTree实际上是一个嵌套的字典
    10     del(labels[bestFeat])                                 #这里开始消耗掉这个用来划分当前数据集的特征
    11     featValues = [example[bestFeat] for example in dataSet]  
    12     uniqueVals = set(featValues)
    13     for value in uniqueVals:
    14         subLabels = labels[:]
    15         myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
    17     return myTree

    11-17行是核心,11-12行找出当前最佳特征列,然后创建这一列唯一的值列表,再对每一个值进行子树的划分

    最后的执行结果:可以看出树的结果是是嵌套的字典

    1 >>> myTree
    2 {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
  • 相关阅读:
    html图片链接不显示图片
    Mybatis总结
    IllegalArgumentException: Could not resolve resource location pattern [classpath .xml]: class path resource cannot be resolved to URL because it does not exist
    java.lang.ClassNotFoundException: com.radiadesign.catalina.session.RedisSessionHandlerValve
    sqlserver2008客户端设置主键自增
    判断手机还是电脑访问
    SSM与jsp传递实体类
    ssm打印sql语句
    SqlSession 同步为注册,因为同步未激活
    for循环取出每个i的值
  • 原文地址:https://www.cnblogs.com/zhengchunhao/p/5604672.html
Copyright © 2011-2022 走看看