  • [Pinned] A Python implementation of the ID3 algorithm

    This article continues from http://blog.csdn.net/xueyunf/article/details/9214727, and some of the helper functions it relies on are in http://blog.csdn.net/xueyunf/article/details/9212827. Because this algorithm takes a fair amount of background to understand, I split the material into three posts; it took me three days to fully work through this classic algorithm myself. Of course, sharper readers may grasp it in far less time, in which case this post probably is not for you and you can safely skip it without missing much.

    Below is the code, with a few pointers for beginners:

    def majorityCnt(classList):
        classCount = {}
        for vote in classList:
            if vote not in classCount:
                classCount[vote] = 0
            classCount[vote] += 1  # increment the tally (assigning 1 here would lose the counts)
        # dict.iteritems() is Python 2 only; items() works in both Python 2 and 3
        sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
        return sortedClassCount[0][0]
     
    
    def createTree(dataSet, labels):
        classList = [example[-1] for example in dataSet]
        if classList.count(classList[0])==len(classList):
            return classList[0]
        if len(dataSet[0])==1:
            return majorityCnt(classList)
        bestFeat = chooseBestFeatureToSplit(dataSet)
        bestFeatLabel = labels[bestFeat]
        myTree = {bestFeatLabel:{}}
        del(labels[bestFeat])
        featValues = [example[bestFeat] for example in dataSet]
        uniqueVals = set(featValues)
        for value in uniqueVals:
            subLabels = labels[:]
            myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
        return myTree

    The first function selects the class name that appears most often in a list of class labels.
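    For reference, the same majority vote can be written more compactly with the standard library's collections.Counter (an alternative sketch, not the code used in this post):

```python
from collections import Counter

def majority_count(class_list):
    # most_common(1) returns [(label, count)] for the most frequent label.
    return Counter(class_list).most_common(1)[0][0]

print(majority_count(['yes', 'no', 'yes']))  # → yes
```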

    The second function builds the decision tree, which is the key piece of code I want to discuss today. Notice that it is recursive, so let me first explain the conditions for ending the recursion: we stop either when all remaining samples belong to the same class, or when every feature has been used up. It is easy to see that the first if statement handles the first case and the second if statement handles the second.
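    The two stopping tests can be checked in isolation; here is a tiny sketch using the same expressions as the two if statements:

```python
# Base case 1: every sample has the same class label, so recursion stops.
class_list = ['no', 'no', 'no']
print(class_list.count(class_list[0]) == len(class_list))  # → True

# Base case 2: only the class label is left in each row (features exhausted),
# so createTree falls back to a majority vote.
row = ['yes']
print(len(row) == 1)  # → True
```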

    In every other case, we pick the best feature using the splitting criterion from the previous post, partition the data on it, insert that feature's label into the tree, and delete the label from the label list. We then recursively call the function on each subset of the remaining data with the remaining labels to build the subtrees, and a complete decision tree emerges.
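    On the sample data below, createTree produces the nested dict {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}. The post itself stops at building the tree; a small helper like the following (my own sketch, not part of the original listing) can walk that dict to classify a new sample:

```python
def classify(tree, feat_labels, test_vec):
    # At each level the single key is a feature label and its sub-keys
    # are the possible values of that feature.
    feat_label = next(iter(tree))
    feat_index = feat_labels.index(feat_label)
    subtree = tree[feat_label][test_vec[feat_index]]
    if isinstance(subtree, dict):
        return classify(subtree, feat_labels, test_vec)
    return subtree  # reached a leaf: a class label

# Tree shape produced by createTree on the sample dataset:
tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
print(classify(tree, ['no surfacing', 'flippers'], [1, 0]))  # → no
```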

    Here is a screenshot of the program running (a picture is worth a thousand words; the Python IDE I use is Eric5, which I happily recommend).
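    The first value the script prints, the Shannon entropy of the whole sample set, is easy to verify by hand: the five samples carry 2 'yes' and 3 'no' labels, so

```python
import math

# Entropy of the sample labels: 2 'yes' and 3 'no' out of 5 samples.
probs = [2 / 5, 3 / 5]
entropy = -sum(p * math.log(p, 2) for p in probs)
print(entropy)  # ≈ 0.9709505944546686
```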

    Finally, here is the complete code from all three posts:

    import math 
    import operator
    
    def calcShannonEnt(dataset):
        numEntries = len(dataset)
        labelCounts = {}
        for featVec in dataset:
            currentLabel = featVec[-1]
            if currentLabel not in labelCounts.keys():
                labelCounts[currentLabel] = 0
            labelCounts[currentLabel] +=1
            
        shannonEnt = 0.0
        for key in labelCounts:
            prob = float(labelCounts[key])/numEntries
            shannonEnt -= prob*math.log(prob, 2)
        return shannonEnt
        
    def CreateDataSet():
        dataset = [[1, 1, 'yes' ], 
                   [1, 1, 'yes' ], 
                   [1, 0, 'no'], 
                   [0, 1, 'no'], 
                   [0, 1, 'no']]
        labels = ['no surfacing', 'flippers']
        return dataset, labels
    
    def splitDataSet(dataSet, axis, value):
        retDataSet = []
        for featVec in dataSet:
            if featVec[axis] == value:
                reducedFeatVec = featVec[:axis]
                reducedFeatVec.extend(featVec[axis+1:])
                retDataSet.append(reducedFeatVec)
        
        return retDataSet
    
    def chooseBestFeatureToSplit(dataSet):
        numberFeatures = len(dataSet[0])-1
        baseEntropy = calcShannonEnt(dataSet)
        bestInfoGain = 0.0
        bestFeature = -1
        for i in range(numberFeatures):
            featList = [example[i] for example in dataSet]
            print(featList)
            uniqueVals = set(featList)
            print(uniqueVals)
            newEntropy =0.0
            for value in uniqueVals:
                subDataSet = splitDataSet(dataSet, i, value)
                prob = len(subDataSet)/float(len(dataSet))
                newEntropy += prob * calcShannonEnt(subDataSet)
            infoGain = baseEntropy - newEntropy
            if(infoGain > bestInfoGain):
                bestInfoGain = infoGain
                bestFeature = i
        return bestFeature
    
    def majorityCnt(classList):
        classCount = {}
        for vote in classList:
            if vote not in classCount:
                classCount[vote] = 0
            classCount[vote] += 1  # increment the tally (assigning 1 here would lose the counts)
        # dict.iteritems() is Python 2 only; items() works in both Python 2 and 3
        sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
        return sortedClassCount[0][0]
     
    
    def createTree(dataSet, labels):
        classList = [example[-1] for example in dataSet]
        if classList.count(classList[0])==len(classList):
            return classList[0]
        if len(dataSet[0])==1:
            return majorityCnt(classList)
        bestFeat = chooseBestFeatureToSplit(dataSet)
        bestFeatLabel = labels[bestFeat]
        myTree = {bestFeatLabel:{}}
        del(labels[bestFeat])
        featValues = [example[bestFeat] for example in dataSet]
        uniqueVals = set(featValues)
        for value in uniqueVals:
            subLabels = labels[:]
            myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
        return myTree
    
            
            
    myDat,labels = CreateDataSet()
    print(calcShannonEnt(myDat))
    
    print(splitDataSet(myDat, 1, 1))
    
    print(chooseBestFeatureToSplit(myDat))
    
    print(createTree(myDat, labels))
    



  • Original post: https://www.cnblogs.com/snake-hand/p/3167777.html