zoukankan      html  css  js  c++  java
  • 机器学习实战-决策树

    这是本文所用的数据集

    海洋生物数据

      不浮出水面是否可以生存 是否有脚踝 属于鱼类
    1
    2
    3
    4
    5

    1.思想

      决策树是机器学习里面比较常见的一种算法。决策树它是这样工作的:给你一个海洋生物的数据集,那么我怎么来判断这个样本是否属于鱼类?我们常规的是不是首先观察它是否在水面上能够生存,如果不能,那么根据常识我们知道它不是海洋生物。如果能够生存,那么我们接下来又回去观察它是否有脚踝,如果有,我们判断它为海洋生物,如果没有,我们判断它不是海洋生物。 简单的说,上面这个判断的过程就是决策的过程!

      这里,我们就有另外一个问题了,我们为何选择首先观察它是否在水面上能够生存,然后再观察它是否有脚踝呢? 我们这篇文章采用ID3算法,那么这就涉及到信息增益的问题了。关于这个问题理解以及公式推导,我们可以参考这个博客。简单的来说,就是我们选择这个特征能够让我们样本集合获得的“纯度提升”越大。

    2.伪代码

      训练集D={(x1,y1),(x2,y2),...,(xm,ym)}

      属性集A={a1,a2,...,ad}

        TreeGenerate(D,A)

        1.生成节点node

        2.if D中所有样本属于同一类别C then

        3.  将node标记为类别为C的叶子节点;   return

        4.else if A=Ø 或者 D中样本在A上取值相同 then

        5.  将node标记为叶节点,其类别标记为D中样本数最多的类;  return

        6.从A中选取最优属性ak

        7.for ak中的每一个值aki:

        8.  为node生成一个分支;令Di表示D在属性ak上取值为aki的样本子集

        9.  if Di = Ø then

      10.    将分支节点标记为叶节点,类别标记为D中样本最多的类;  return

      11.   else

      12.    以TreeGenerate(Di,A{ak})为分支节点递归创建决策树

    3.代码实现

    import numpy as np
    from math import log
    
    #创建数据集
    def createDataSet():
        # data =    [[0, 0, 0, 0, 'no'],  # 数据集
        #            [0, 0, 0, 1, 'no'],
        #            [0, 1, 0, 1, 'yes'],
        #            [0, 1, 1, 0, 'yes'],
        #            [0, 0, 0, 0, 'no'],
        #            [1, 0, 0, 0, 'no'],
        #            [1, 0, 0, 1, 'no'],
        #            [1, 1, 1, 1, 'yes'],
        #            [1, 0, 1, 2, 'yes'],
        #            [1, 0, 1, 2, 'yes'],
        #            [2, 0, 1, 2, 'yes'],
        #            [2, 0, 1, 1, 'yes'],
        #            [2, 1, 0, 1, 'yes'],
        #            [2, 1, 0, 2, 'yes'],
        #            [2, 0, 0, 0, 'no']]
        # labels = ['年龄', '有工作', '有自己的房子', '信贷情况']
        data = [[1,1,'yes'],
                [1,1,'yes'],
                [1,0,'no'],
                [0,1,'no'],
                [0,1,'no']]
        labels = ['no surfacing','flippers']
        return data,labels
    
    #计算香农熵
    def calEnt(dataSet):
        labelsCount ={}
        num = len(dataSet)
        for featVec in dataSet:
            currentLabel = featVec[-1]
            if currentLabel  not in labelsCount.keys():
                labelsCount[currentLabel]=1
            else:
                labelsCount[currentLabel]+=1
        prob = 0.0
        for key in labelsCount:
            p = float(labelsCount[key]) / num
            prob -= p * log(p,2)
        return prob
    #得到相应子集
    def splitDataSet(dataSet,axis,value):   #axis=n 则表示取第n个特征列,且特征取值为value的子数据集
        subDataSet = []
        for data in dataSet:
            if data[axis] == value:
                reData = data[:axis]
                reData.extend(data[axis+1:])
                subDataSet.append(reData)
        return subDataSet
    
    #得到最佳的划分特征
    def getbestFeat(dataSet):
        num_features = len(dataSet[0]) - 1  # 特征数2
        num = len(dataSet)                  # 样本数
        baseInfoGain = 0.0
        for feature in range(num_features):
            #得到该特征有几个属性值
            feature_data = [example[feature] for example in dataSet]
            feature_property = set(feature_data)
            labelsCount = {}
            newEntropy = 0.0
            for label in feature_data:
                if label not in labelsCount:
                    labelsCount[label] = 0
                labelsCount[label] += 1
            for property in feature_property: #属性值 0 1
                subSet = splitDataSet(dataSet, feature, property)
                prob = float(labelsCount[property]) / num
                newEntropy = newEntropy + prob*calEnt(subSet)
            InfoGain = calEnt(dataSet) - newEntropy
            #print('第',feature,'个特征的增益为:',InfoGain)
            if InfoGain>baseInfoGain:
                baseInfoGain = InfoGain
                bestFeat = feature
        return bestFeat
    
    
    #投票
    def majority(classList):
        classCount = {}
        for vote in classList:
            if vote not in classCount:
                classCount[vote] = 0
            classCount[vote] += 1
        cla = sorted(classCount.items(),key = lambda x:x[1],reverse=True)
        return cla[0][0]
    
    #创建决策树
    def createDecisionTree(dataSet,labels):
        classList = [example[-1] for example in dataSet]
        classListSet = set(classList)
        if len(classListSet) == 1:
            return classList[0]
        if len(dataSet[0]) == 1:
            return majority(classList)
        bestFeat = getbestFeat(dataSet)
        print(bestFeat)
        bestLabel = labels[bestFeat]
        del(labels[bestFeat])
        mytree = {bestLabel:{}}
        uniqueProperty = {}
        for property in [example[bestFeat] for example in dataSet]:
            if property not in uniqueProperty:
                uniqueProperty[property] = 0
            uniqueProperty[property] += 1
        for value in uniqueProperty.keys():
            subLabels = labels[:]
            subSet = splitDataSet(dataSet,bestFeat,value)
            mytree[bestLabel][value] = createDecisionTree(subSet,subLabels)
        return mytree
    
    if __name__ =='__main__':
        data,labels = createDataSet()
        result = calEnt(data)
        print(result)  #0.970950594454668
    
        # print(majority([1,0,1,0,0,0]))
        # getbestFeat(data)
        print(createDecisionTree(data,labels))   #{'no surfacing': {1: {'flippers': {1: 'yes', 0: 'no'}}, 0: 'no'}}


    ###################################
    #获取叶子节点的数目
    def getNumLeafs(myTree):
        numLeafs = 0
        firstStr = list(myTree.keys())[0]
        secondDict = myTree[firstStr]
        for i in secondDict.keys():
            if type(secondDict[i]).__name__ == 'dict':
                numLeafs += getNumLeafs(secondDict[i])
            else:
                numLeafs += 1
        return numLeafs
    
    #获取树的层数
    def getTreeDepth(myTree):
        maxDepth = 0
        firstStr = next(iter(myTree))
        secondDict = myTree[firstStr]
        for i in secondDict.keys():
            if type(secondDict[i]).__name__ == 'dict':
                thisDepth = 1 + getTreeDepth(secondDict[i])
            else:
                thisDepth = 1
            if thisDepth>maxDepth:
                maxDepth = thisDepth
        return maxDepth
  • 相关阅读:
    acme.sh 申请let's encrypt证书
    Excel 函数
    mysql索引失效的情况
    mysql之EXPLAIN优化分析
    mysql索引
    mysql视图
    mysql数据类型
    mysql约束
    mysql库和表的管理
    mysql的DML语言(增删改)
  • 原文地址:https://www.cnblogs.com/logo-88/p/10125875.html
Copyright © 2011-2022 走看看