zoukankan      html  css  js  c++  java
  • 李航——决策树代码

    # -*- coding: utf-8 -*-
    """
    Created on Tue May 15 15:28:42 2018
    
    @author: baochen
    """
    from math import log
    import numpy
    import operator
    
    def CalEnt(database):
        """Return the base-2 Shannon entropy of the class labels in ``database``.

        The class label of a row is its last element.  An empty ``database``
        yields 0.0.
        """
        total = len(database)

        # Tally how many rows carry each class label.
        tally = {}
        for row in database:
            tally[row[-1]] = tally.get(row[-1], 0) + 1

        # H = -sum(p * log2(p)) over the labels that actually occur.
        entropy = 0.0
        for count in tally.values():
            p = count / total
            entropy = entropy - p * log(p, 2)
        return entropy
    
    def CreatDatabase():
        """Return a fresh copy of the 15-row sample data set.

        Every row holds four integer feature values followed by the class
        label ('yes' or 'no') as its last element.
        """
        return [
            # rows whose first feature is 1
            [1, 0, 0, 0, 'no'],
            [1, 0, 0, 1, 'no'],
            [1, 1, 0, 1, 'yes'],
            [1, 1, 1, 0, 'yes'],
            [1, 0, 0, 0, 'no'],
            # rows whose first feature is 2
            [2, 0, 0, 0, 'no'],
            [2, 0, 0, 1, 'no'],
            [2, 1, 1, 1, 'yes'],
            [2, 0, 1, 2, 'yes'],
            [2, 0, 1, 2, 'no'],
            # rows whose first feature is 3
            [3, 0, 1, 2, 'yes'],
            [3, 0, 1, 1, 'no'],
            [3, 1, 0, 1, 'yes'],
            [3, 1, 0, 2, 'yes'],
            [3, 0, 0, 0, 'no'],
        ]
    # axis: index of the column to test; value: the value that column must equal
    def SplitDatabase(database, axis, value):
        """Return the rows of ``database`` whose entry at ``axis`` equals ``value``.

        Matching rows are returned whole (the tested column is kept), in their
        original order, and are the same row objects as in ``database``.
        """
        return [row for row in database if row[axis] == value]
    
    def ChooseBest(database):
        """Return the index of the unused feature with the largest information gain.

        The gain of feature ``i`` is ``H(D) - H(D | A_i)``, where ``H(D)`` is the
        entropy of the class labels (last element of every row) and
        ``H(D | A_i)`` is the size-weighted average entropy of the subsets
        obtained by splitting on each distinct value of column ``i``.

        Feature indices listed in the module-level list ``t`` are skipped; if
        ``t`` has not been created yet, no feature is skipped.

        Parameters:
            database: list of rows; each row is a sequence of feature values
                followed by the class label.

        Returns:
            The index of the best feature, or -1 when ``database`` is empty or
            no unused feature has a strictly positive information gain.
        """
        if not database:
            return -1

        # Features already consumed on the current tree path live in the
        # module-level list ``t``; tolerate it being absent.
        try:
            used = t
        except NameError:
            used = ()

        baseShanon = CalEnt(database)
        total = float(len(database))
        FeaNum = len(database[0]) - 1  # the last column is the class label

        bestInformationGain = 0.0
        BestChoose = -1

        for i in range(FeaNum):
            if i in used:
                continue

            # Conditional entropy H(D | A_i): it must start from zero for every
            # feature and is a *sum* of the weighted subset entropies.
            ConShanon = 0.0
            for value in {row[i] for row in database}:
                subdatabase = SplitDatabase(database, i, value)
                ConShanon += len(subdatabase) / total * CalEnt(subdatabase)

            # Compare only once all values of the feature are accounted for.
            InformationGain = baseShanon - ConShanon
            if InformationGain > bestInformationGain:
                bestInformationGain = InformationGain
                BestChoose = i

        return BestChoose
                
    def majorityEnt(LabelList):
        """Return the class label that occurs most often in ``LabelList``.

        Parameters:
            LabelList: non-empty sequence of hashable class labels.

        Returns:
            The most frequent label.  When several labels tie, the winner is
            the first of them in the tally dict's iteration order (the first
            one seen, on interpreters whose dicts keep insertion order).

        Raises:
            ValueError: if ``LabelList`` is empty.
        """
        LabelCount = {}
        for vote in LabelList:
            LabelCount[vote] = LabelCount.get(vote, 0) + 1
        return max(LabelCount, key=LabelCount.get)


    def CreatTree(database,label):
        """Recursively build a decision tree for ``database``.

        Parameters:
            database: non-empty list of rows; each row is a sequence of feature
                values followed by the class label.
            label: feature names, indexed by feature column.

        Returns:
            Either a class label (leaf) or a nested dict of the form
            ``{feature_name: {feature_value: subtree_or_label, ...}}``.

        Side effects:
            Prints the name of every feature chosen for a split.  While a
            subtree is being built, the chosen feature index is kept in the
            module-level list ``t`` (created here if missing) so that
            ``ChooseBest`` skips it; it is removed again afterwards, leaving
            ``t`` as it was on entry.
        """
        global t

        LabelList = [row[-1] for row in database]

        # All rows share one class: nothing left to decide.
        if LabelList.count(LabelList[0]) == len(LabelList):
            return LabelList[0]

        # Rows consist of the class label only: fall back to a majority vote.
        if len(database[0]) == 1:
            return majorityEnt(LabelList)

        # ChooseBest reads the module-level list ``t``; make sure it exists.
        if 't' not in globals():
            t = []

        Feature = ChooseBest(database)
        # -1 means no usable feature is left (all used, or none helps).  Indexing
        # ``label[-1]`` / column -1 would split on the class label itself, so
        # settle the node by majority vote instead.
        if Feature == -1:
            return majorityEnt(LabelList)

        BestLabel = label[Feature]
        print(BestLabel)  # progress trace kept from the original script

        MyTree = {BestLabel:{}}

        # Mark the feature as used only for this subtree, so that sibling
        # branches elsewhere in the tree may still split on it.
        t.append(Feature)
        try:
            for value in {row[Feature] for row in database}:
                MyTree[BestLabel][value] = CreatTree(SplitDatabase(database,Feature,value),label)
        finally:
            t.pop()

        return MyTree
        
        
    
        
            
    
    
    def test():
        """Build the decision tree for the sample data set and print it.

        The module-level list ``t`` is reset to an empty list before the tree
        is built.
        """
        global t
        t = []

        # Names of the four feature columns, in column order.
        feature_names = ['age', 'work', 'house', 'money']
        tree = CreatTree(CreatDatabase(), feature_names)
        print(tree)
        
    # Build and print the sample tree as soon as the script is executed.
    test()
    money
    age
    work
    house
    money
    
    {'money': {0: {'age': {1: {'work': {0: 'no', 1: 'yes'}}, 2: 'no', 3: 'no'}}, 1: {'house': {0: {'money': {'no': 'no', 'yes': 'yes'}}, 1: {'money': {'no': 'no', 'yes': 'yes'}}}}, 2: {'money': {'no': 'no', 'yes': 'yes'}}}}

    感觉最后面的地方有点乱了

    有空优化一下。

  • 相关阅读:
    程序打印的日志哪里去了?结合slf4j来谈谈面向接口编程的重要性
    vue项目用npm安装sass包遇到的问题及解决办法
    nginx反向代理配置及常见指令
    你以为你以为的就是你以为的吗?记一次服务器点对点通知的联调过程
    jeecg逆向工程代码的生成及常见问题
    java注解
    终于有了,史上最强大的数据脱敏处理算法
    SpringBoot项目下的JUnit测试
    递归方法
    练习题
  • 原文地址:https://www.cnblogs.com/baochen/p/9046290.html
Copyright © 2011-2022 走看看