  • [Pinned] A Python Implementation of the ID3 Algorithm

    This post continues from http://blog.csdn.net/xueyunf/article/details/9214727, and some of the helper functions are in http://blog.csdn.net/xueyunf/article/details/9212827. Because there is a lot to digest in this algorithm, I split the material across three posts; it took me three days to fully understand this classic algorithm. Of course, if you are sharp and grasped the algorithm quickly, this post is not for you and can be skipped, as you won't gain much from reading it.

    Below is the code, with a few pointers for beginners:

    def majorityCnt(classList):
        # Return the class label that occurs most often in classList.
        classCount = {}
        for vote in classList:
            if vote not in classCount:
                classCount[vote] = 0
            classCount[vote] += 1   # was "= 1", which never accumulates the count
        # items() replaces the Python 2-only iteritems()
        sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
        return sortedClassCount[0][0]
     
    
    def createTree(dataSet, labels):
        classList = [example[-1] for example in dataSet]
        # Base case 1: every remaining example has the same class.
        if classList.count(classList[0]) == len(classList):
            return classList[0]
        # Base case 2: no features left to split on; take a majority vote.
        if len(dataSet[0]) == 1:
            return majorityCnt(classList)
        bestFeat = chooseBestFeatureToSplit(dataSet)
        bestFeatLabel = labels[bestFeat]
        myTree = {bestFeatLabel: {}}
        del(labels[bestFeat])
        featValues = [example[bestFeat] for example in dataSet]
        uniqueVals = set(featValues)
        for value in uniqueVals:
            subLabels = labels[:]   # copy so recursion does not mutate the caller's labels
            myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
        return myTree

    The first function selects the class name that occurs most often.
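    Condensed, the idea in majorityCnt can be sanity-checked on its own. This standalone sketch uses dict.get in place of the key-membership test, but is otherwise equivalent:

```python
import operator

def majority_cnt(class_list):
    # Tally how many times each class label appears.
    class_count = {}
    for vote in class_list:
        class_count[vote] = class_count.get(vote, 0) + 1
    # Sort label/count pairs by count, descending; the first entry wins.
    ranked = sorted(class_count.items(), key=operator.itemgetter(1), reverse=True)
    return ranked[0][0]

print(majority_cnt(['yes', 'no', 'yes', 'no', 'no']))  # → no
```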

    The second function builds the decision tree, which is the key part of today's code. Notice that it is recursive, so let me first explain the conditions for ending the recursion: we stop either when all remaining examples belong to the same class, or when all of the features have been used up. It is easy to see that the first if handles the first case and the second if handles the second.

    When neither stopping condition holds, we use the best-split selection from the previous post to partition the data, insert that feature's label into the tree, and delete it from the label list; then the remaining data and labels are passed back into the function to recursively build each subtree. Repeating this process yields a complete decision tree.
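    Once the tree is built, classifying a new sample is just a walk down the nested dict. The classify helper below is my own addition (it is not part of the original three posts); the hard-coded my_tree is the tree that createTree produces for the sample dataset used later:

```python
def classify(input_tree, feat_labels, test_vec):
    # The single key at this node is the feature to test.
    first_str = list(input_tree.keys())[0]
    second_dict = input_tree[first_str]
    feat_index = feat_labels.index(first_str)
    branch = second_dict[test_vec[feat_index]]
    if isinstance(branch, dict):
        return classify(branch, feat_labels, test_vec)  # internal node: recurse
    return branch  # leaf node: a class label

# Tree produced for the 5-row sample dataset from CreateDataSet.
my_tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
print(classify(my_tree, ['no surfacing', 'flippers'], [1, 1]))  # → yes
print(classify(my_tree, ['no surfacing', 'flippers'], [1, 0]))  # → no
```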

    Below is a screenshot of the program running (as they say, no picture, no truth; the screenshot is not reproduced here). Incidentally, the Python IDE I use is Eric5, which I recommend.

    Finally, here is the complete code from all three posts:

    import math 
    import operator
    
    def calcShannonEnt(dataset):
        numEntries = len(dataset)
        labelCounts = {}
        for featVec in dataset:
            currentLabel = featVec[-1]
            if currentLabel not in labelCounts.keys():
                labelCounts[currentLabel] = 0
            labelCounts[currentLabel] +=1
            
        shannonEnt = 0.0
        for key in labelCounts:
            prob = float(labelCounts[key])/numEntries
            shannonEnt -= prob*math.log(prob, 2)
        return shannonEnt
        
    def CreateDataSet():
        dataset = [[1, 1, 'yes' ], 
                   [1, 1, 'yes' ], 
                   [1, 0, 'no'], 
                   [0, 1, 'no'], 
                   [0, 1, 'no']]
        labels = ['no surfacing', 'flippers']
        return dataset, labels
    
    def splitDataSet(dataSet, axis, value):
        retDataSet = []
        for featVec in dataSet:
            if featVec[axis] == value:
                reducedFeatVec = featVec[:axis]
                reducedFeatVec.extend(featVec[axis+1:])
                retDataSet.append(reducedFeatVec)
        
        return retDataSet
    
    def chooseBestFeatureToSplit(dataSet):
        numberFeatures = len(dataSet[0])-1
        baseEntropy = calcShannonEnt(dataSet)
        bestInfoGain = 0.0
        bestFeature = -1
        for i in range(numberFeatures):
            featList = [example[i] for example in dataSet]
            print(featList)
            uniqueVals = set(featList)
            print(uniqueVals)
            newEntropy =0.0
            for value in uniqueVals:
                subDataSet = splitDataSet(dataSet, i, value)
                prob = len(subDataSet)/float(len(dataSet))
                newEntropy += prob * calcShannonEnt(subDataSet)
            infoGain = baseEntropy - newEntropy
            if(infoGain > bestInfoGain):
                bestInfoGain = infoGain
                bestFeature = i
        return bestFeature
    
    def majorityCnt(classList):
        classCount = {}
        for vote in classList:
            if vote not in classCount:
                classCount[vote] = 0
            classCount[vote] += 1
        sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
        return sortedClassCount[0][0]
     
    
    def createTree(dataSet, labels):
        classList = [example[-1] for example in dataSet]
        if classList.count(classList[0])==len(classList):
            return classList[0]
        if len(dataSet[0])==1:
            return majorityCnt(classList)
        bestFeat = chooseBestFeatureToSplit(dataSet)
        bestFeatLabel = labels[bestFeat]
        myTree = {bestFeatLabel:{}}
        del(labels[bestFeat])
        featValues = [example[bestFeat] for example in dataSet]
        uniqueVals = set(featValues)
        for value in uniqueVals:
            subLabels = labels[:]
            myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
        return myTree
    
            
            
    myDat,labels = CreateDataSet()
    print(calcShannonEnt(myDat))
    
    print(splitDataSet(myDat, 1, 1))
    
    print(chooseBestFeatureToSplit(myDat))
    
    print(createTree(myDat, labels))
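
    As a sanity check on calcShannonEnt: the sample dataset has two 'yes' labels and three 'no' labels, so its entropy should be -(2/5)log2(2/5) - (3/5)log2(3/5) ≈ 0.971, which matches the first print above:

```python
import math

# Hand-computed entropy for a 5-row set with 2 'yes' and 3 'no' labels.
p_yes, p_no = 2 / 5, 3 / 5
entropy = -p_yes * math.log(p_yes, 2) - p_no * math.log(p_no, 2)
print(round(entropy, 4))  # → 0.971
```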
    



  • Original post: https://www.cnblogs.com/snake-hand/p/3167777.html