zoukankan      html  css  js  c++  java
  • 决策树分类原理

    上一篇博客我们看了一个决策树分类的例子,但是我们没有深入决策树分类的内部原理。

    这节我们讨论的决策树分类的所有特征的特征值都是离散的,明白了离散特征值如何分类的原理,连续值的也不难理解。

    决策树分类的核心在于确定那一个特征的那一个特征值分类最有效,可能不同的场景,每个人采用的衡量方法也不一样,这里我们采用香农熵。

    下面我们看一下简单的例子

    五个样例,两个特征(是否浮上水面,是否有鳍),判断该动物是否是水生(类别)

    def createDataSet():
        dataSet = [[1, 1, 'yes'],
                   [1, 1, 'yes'],
                   [1, 0, 'no'],
                   [0, 1, 'no'],
                   [0, 1, 'no']]
        labels = ['no surfacing','flippers']
        return dataSet, labels

    计算特征划分的香农熵:传入一个数据集,根据类别计算香农熵(每一个类别的概率p*log(p,2)的和*-1)

    def calcShannonEnt(dataSet):
        #样本个数
        numEntries = len(dataSet)
        labelCounts = {}
        for featVec in dataSet: #
            currentLabel = featVec[-1]
            if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
            labelCounts[currentLabel] += 1
        shannonEnt = 0.0
        for key in labelCounts:
            prob = float(labelCounts[key])/numEntries
            shannonEnt -= prob * log(prob,2) #log base 2
        return shannonEnt

    划分数据集:根据传入的特征(对应数据中的那一列)和特征值对数据进行划分(把传入特征对应的列删掉,把特征列等于传入特征值的行删掉),返回根据该特征的特征值划分的结果数据集

    def splitDataSet(dataSet, axis, value):
        retDataSet = []
        for featVec in dataSet:
            if featVec[axis] == value:
                reducedFeatVec = featVec[:axis]     #chop out axis used for splitting
                reducedFeatVec.extend(featVec[axis+1:])
                retDataSet.append(reducedFeatVec)
        return retDataSet

    寻找最佳特征的特征值:遍历所有特征的特征值,寻找最好的结果

    def chooseBestFeatureToSplit(dataSet):
        numFeatures = len(dataSet[0]) - 1      #
        baseEntropy = calcShannonEnt(dataSet)
        bestInfoGain = 0.0; bestFeature = -1
        for i in range(numFeatures):        #遍历特征
            featList = [example[i] for example in dataSet]#得到特征列
            uniqueVals = set(featList)       #从特征列获取该特征的特征值的set集合
            newEntropy = 0.0
            for value in uniqueVals:
                subDataSet = splitDataSet(dataSet, i, value)
                prob = len(subDataSet)/float(len(dataSet))
                newEntropy += prob * calcShannonEnt(subDataSet)     
            infoGain = baseEntropy - newEntropy     #计算当前划分的香农熵
            if (infoGain > bestInfoGain):       #比较是否是最好的结果
                bestInfoGain = infoGain         #记录最好的结果和最好的特征
                bestFeature = i
        return bestFeature                      #返回最好特征的对应的列下标

    找到最好的特征,然后以最好的特征作为根节点,所有的特征值作为分支,继续划分数据集,直到每个小数据集里面的类别都一样或者没有可以继续划分的特征的

    存在一种情况,已经没有可以划分的特征了,但是这个数据集的类别有多个,怎么办?按照前一篇博客说的,多数表决的方法决定改分支的类别:

    下面的函数接受一个类别的列表,返回类别数多的类别

    def majorityCnt(classList):
        classCount={}
        for vote in classList:
            if vote not in classCount.keys(): classCount[vote] = 0
            classCount[vote] += 1
        sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
        return sortedClassCount[0][0]

    知道了怎么划分数据集,也知道划分停止的条件,下面我们可以生成决策树了

    def createTree(dataSet,labels):
        classList = [example[-1] for example in dataSet]
        if classList.count(classList[0]) == len(classList): 
            return classList[0]#所有的类别都一样,就不用再划分了
        if len(dataSet[0]) == 1: #如果没有继续可以划分的特征,就多数表决决定分支的类别
            return majorityCnt(classList)
        bestFeat = chooseBestFeatureToSplit(dataSet)
        bestFeatLabel = labels[bestFeat]
        myTree = {bestFeatLabel:{}}
        del(labels[bestFeat])
        featValues = [example[bestFeat] for example in dataSet]
        uniqueVals = set(featValues)
        for value in uniqueVals:
            subLabels = labels[:]       #copy all of labels, so trees don't mess up existing labels
            myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
        return myTree  

    下面是生成的决策树:

    {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

    画成图示如下:

    希望到这里能理解理解决策树分类的原理了

  • 相关阅读:
    mysql timestamp字段定义的
    mybatis的Selective接口和普通接口的区别
    intllij IDE 中git ignore 无法删除target目录下的文件
    maven的单元测试中没有
    java volatile关键字
    RestExpress response中addHeader 导致stackOverflow
    log4j配置后行号乱码显示为?问号
    软件研发人员的职业发展规划
    CPU与内存互联的架构演变
    windows系统安装
  • 原文地址:https://www.cnblogs.com/qwj-sysu/p/5970066.html
Copyright © 2011-2022 走看看