  • 《机器学习实战》 (Machine Learning in Action), Listing 3-4: code for the tree-building function

    This is still a bit messy; I'll come back and tidy it up once I've thought it through completely.

    from math import log
    import operator
    
    def calcShannonEnt(dataSet):
        numEntries = len(dataSet)
        #print("样本总数:" + str(numEntries))
    
        labelCounts = {}  # counts how many samples fall in each class label
    
        # iterate over every feature vector (row) in the data set
        for featVec in dataSet:
            
            currentLabel = featVec[-1]  # the last column is the class label
    
            if currentLabel not in labelCounts.keys():
                labelCounts[currentLabel] = 0

            labelCounts[currentLabel] += 1  # number of times the label currentLabel has appeared
            #print("Current labelCounts: " + str(labelCounts))
    
        shannonEnt = 0.0
    
        for key in labelCounts:
            
            prob = float(labelCounts[key]) / numEntries  # probability of this class label
    
            #print(str(key) + "类别的概率:" + str(prob))
            #print(prob * log(prob, 2) )
            shannonEnt -= prob * log(prob, 2) 
            #print("熵值:" + str(shannonEnt))
    
        return shannonEnt
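    # A quick sanity check for calcShannonEnt, using the 5-row sample dataSet defined in
    # createDataSet() below (2 'yes' labels and 3 'no' labels):
    #   H = -(2/5)*log2(2/5) - (3/5)*log2(3/5) ≈ 0.971
    # so calcShannonEnt(createDataSet()[0]) should return roughly 0.971.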
    
    def createDataSet():
        dataSet = [
            [1, 1, 'yes'],
            [1, 1, 'yes'],
            [1, 0, 'no'],
            [0, 1, 'no'],
            [0, 1, 'no'],
            # Rows can be added below at will to see how the entropy changes: the more mixed and conflicting the labels, the larger the entropy.
            # [1, 1, 'no'],
            # [1, 1, 'no'],
            # [1, 1, 'no'],
            # [1, 1, 'no'],
            #[1, 1, 'maybe'],
            # [1, 1, 'maybe1']
            # The 8 more extreme rows below make it even clearer; keep adding rows in this pattern and the entropy keeps growing.
            # [1,1,'1'],
            # [1,1,'2'],
            # [1,1,'3'],
            # [1,1,'4'],
            # [1,1,'5'],
            # [1,1,'6'],
            # [1,1,'7'],
            # [1,1,'8'],
    
            # The opposite extreme: every sample has the same class, fully ordered, no disorder, so the entropy is 0.
            # [1,1,'1'],
            # [1,1,'1'],
            # [1,1,'1'],
            # [1,1,'1'],
            # [1,1,'1'],
            # [1,1,'1'],
            # [1,1,'1'],
            # [1,1,'1'],        
        ]
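        # For intuition (assuming dataSet consisted only of the 8 extreme rows commented out above,
        # each row in its own class): the entropy would be -8 * (1/8) * log2(1/8) = log2(8) = 3 bits,
        # while a data set whose rows all share one class has entropy 0.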
    
        #print("dataSet[0]:" + str(dataSet[0]))
        #print(dataSet)
    
        labels = ['no surfacing', 'flippers']
    
        return dataSet, labels
    
    def testCalcShannonEnt():
    
        myDat, labels = createDataSet()
        #print(calcShannonEnt(myDat))
    
    def splitDataSet(dataSet, axis, value):
        retDataSet = []
        for featVec in dataSet:
            #print("featVec:" + str(featVec))
            #print("featVec[axis]:" + str(featVec[axis]))
            if featVec[axis] == value:
                reduceFeatVec = featVec[:axis]
                #print(reduceFeatVec)
                reduceFeatVec.extend(featVec[axis + 1:])
                #print('reduceFeatVec:' + str(reduceFeatVec))
                retDataSet.append(reduceFeatVec)
        #print("retDataSet:" + str(retDataSet))
        return retDataSet
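    # Example with myDat from createDataSet():
    #   splitDataSet(myDat, 0, 1) -> [[1, 'yes'], [1, 'yes'], [0, 'no']]
    #   splitDataSet(myDat, 0, 0) -> [[1, 'no'], [1, 'no']]
    # i.e. the rows whose feature 0 equals the given value, with that feature column removed.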
    
    def testSplitDataSet():
        myDat,labels = createDataSet()
        #print(myDat)
        a = splitDataSet(myDat, 0, 0)
        #print(a)
    
    
    def chooseBestFeatureToSplit(dataSet):
        numFeatures = len(dataSet[0]) - 1  # drop the class-label column, leaving 2 feature columns
        #print("Number of features: " + str(numFeatures))
    
        baseEntropy = calcShannonEnt(dataSet)
        #print("基础熵:" + str(baseEntropy))
    
        bestInfoGain = 0.0
        bestFeature  = -1
    
        #numFeatures==2
        for i in range(numFeatures):
            #print("i的值" + str(i))
            featList = [example[i] for example in dataSet];
            #print("featList:" + str(featList))
    
            # Building a set from the list is the fastest way in Python to get the unique values in a list.
            # A set is an unordered collection of hashable values; converting to a set collapses duplicates.
            # e.g. [1, 0, 1, 1, 1, 1] becomes {0, 1}
            uniqueVals = set(featList) 
            #print("uniqueVals" + str(uniqueVals))
    
            newEntropy = 0.0
            #uniqueVals=={0,1}
            for value in uniqueVals:
                subDataSet = splitDataSet(dataSet, i, value)
                #print("subDataSet:" + str(subDataSet))
                prob = len(subDataSet) / float(len(dataSet))
                
                #print("subDataSet:" + str(subDataSet))
                #print("subDataSet的长度:" + str(len(subDataSet)))
                newEntropy += prob * calcShannonEnt(subDataSet)
                #print("newEntropy:" + str(newEntropy))
    
            # Information gain: the lower the entropy after the split, the larger the gain; the goal is to find the split with the largest gain.
            infoGain = baseEntropy - newEntropy 
            #print("infoGain:" + str(infoGain))
            #print("bestInfoGain:" + str(bestInfoGain))
    
    
            if infoGain > bestInfoGain:
                bestInfoGain = infoGain
                bestFeature = i
    
        #print("bestFeature:" + str(bestFeature))
        return bestFeature
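    # Worked through on the sample dataSet (base entropy ≈ 0.971):
    #   feature 0: new entropy = 3/5 * 0.918 + 2/5 * 0 ≈ 0.551, gain ≈ 0.420
    #   feature 1: new entropy = 4/5 * 1.000 + 1/5 * 0 = 0.800, gain ≈ 0.171
    # so chooseBestFeatureToSplit(myDat) returns 0, i.e. the 'no surfacing' feature.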
                
        
    def testChooseBestFeatureToSplit():
        myDat, labels = createDataSet()
        chooseBestFeatureToSplit(myDat)
    
    '''
    Input: a list of class labels
    Output: the class that occurs most often in the list, i.e. the majority vote
    This function returns the dictionary key with the largest value (count),
    which is the most frequent element of the input list.
    '''
    def majorityCnt(classList):
        classCount={}
        for vote in classList:
            if vote not in classCount.keys(): 
                classCount[vote] = 0
            classCount[vote] += 1
     
        # key=operator.itemgetter(0) or key=operator.itemgetter(1) chooses whether to sort by the dict's keys or its values:
        # 0 sorts by key, 1 sorts by value.
        # reverse defaults to False; reverse=True sorts from largest to smallest.
    
        sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    
        print(sortedClassCount)
    
        return sortedClassCount[0][0]
    def testMajorityCnt():
        list1 = ['a','b','a','a','b','c','d','d','d','e','a','a','a','a','c','c','c','c','c','c','c','c']

        print(majorityCnt(list1))
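    # For list1 above, 'c' is the most frequent element (9 occurrences),
    # so majorityCnt(list1) should return 'c'.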
    
    # module-level counter used to number the recursive calls to createTree
    n = 0
    
    def createTree(dataSet, labels):
        
        global n
        print("=================createTree"+str(n)+" begin=============")
        n += 1
        print(n)
    
        classList = [example[-1] for example in dataSet]
    
        print("" + str(n) + "次classList:" + str(classList))
        print("此时列表中的第1个元素为" + str(classList[0]) + ",数量为:" + str(classList.count(classList[0])) + ",列表总长度为:" + str(len(classList)))
        
        print("列表中"+str(classList[0])+"的数量:",classList.count(classList[0]))
        print("列表的长度:", len(classList))
    
        if classList.count(classList[0]) == len(classList):
            print("Result: all classes are identical, stop splitting this branch")
        else:
            print("Result: the classes differ")
    
        # if the list has n elements and all n are identical, stop the recursion
        if classList.count(classList[0]) == len(classList):
             return classList[0]
    
        print("dataSet[0]:" + str(dataSet[0]))
    
        if len(dataSet[0]) == 1:
            print("启动多数表决")  #书中的示例样本集合没有触发
            return majorityCnt(classList)
    
        bestFeat = chooseBestFeatureToSplit(dataSet)
        bestFeatLabel = labels[bestFeat]
        print("bestFeat:" +str(bestFeat))
        print("bestFeatLabel:" + str(bestFeatLabel))
    
        myTree = {bestFeatLabel:{}}
        print("当前树状态:" + str(myTree))
    
        print("当前标签集合:" + str(labels))
        print("准备删除" + labels[bestFeat])
        del(labels[bestFeat])
        print("已删除")
        print("删除元素后的标签集合:" + str(labels))
    
        featValues = [example[bestFeat] for example in dataSet]
        print("featValues:",featValues)
    
        uniqueVals = set(featValues)
        print("uniqueVals:", uniqueVals) #{0,1}
    
        k = 0
        print("********开始循环******")
        for value in uniqueVals:
    
            k += 1
            print("",k,"次循环")
            subLabels = labels[:]
            print("传入参数:")
            print("        --待划分的数据集:",dataSet)
            print("        --划分数据集的特征:", bestFeat)
            print("        --需要返回的符合特征值:", value)
            splited = splitDataSet(dataSet, bestFeat, value)
            print("splited:", str(splited))
            myTree[bestFeatLabel][value] = createTree(splited, subLabels)  # recursive call
        print("******* loop end *****")
    
        print("=================createTree"+str(n)+" end=============")
        return myTree
         
    def testCreateTree():

        myDat, labels = createDataSet()
        myTree = createTree(myDat, labels)
        print("============testCreateTree=============")
        print(myTree)
    
    if __name__ == '__main__':
        # test printing the information entropy
        #testCalcShannonEnt()
    
        # test splitting the data set
        #testSplitDataSet()
    
        # choose the best feature to split on
        #testChooseBestFeatureToSplit()
     
        #testMajorityCnt()
    
        testCreateTree()
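    # With the 5-row sample dataSet, the tree printed at the end (ignoring the debug
    # output along the way) should be:
    #   {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}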
    
    
        
  • Original post: https://www.cnblogs.com/Sabre/p/8415124.html