zoukankan      html  css  js  c++  java
  • 李航——决策树代码

    # -*- coding: utf-8 -*-
    """
    Created on Tue May 15 15:28:42 2018
    
    @author: baochen
    """
    from math import log
    import numpy
    import operator
    
    def CalEnt(database):
        """Return the Shannon entropy (in bits) of the class labels of database.

        database: sequence of rows; the last element of each row is its class
                  label, every other element is ignored here.

        An empty database has entropy 0.0.
        """
        LabelNumber = len(database)
        LabelDic = {}
        for line in database:
            LineLabel = line[-1]
            LabelDic[LineLabel] = LabelDic.get(LineLabel, 0) + 1

        ShanonEnt = 0.0
        for count in LabelDic.values():
            # float() forces true division on Python 2 as well; with integer
            # division every prob < 1 would become 0 and log(0) would raise.
            prob = count / float(LabelNumber)
            ShanonEnt -= prob * log(prob, 2)
        return ShanonEnt
    
    def CreatDatabase():
        """Return the 15-row loan-application sample set as a new list.

        The rows follow Table 5.1 of Li Hang's "Statistical Learning Methods".
        Each row is [age, work, house, money, label]:
          age   1 = young, 2 = middle-aged, 3 = old
          work  0 = no job, 1 = has a job
          house 0 = no own house, 1 = owns a house
          money credit rating, 0 = fair, 1 = good, 2 = excellent
          label 'yes' / 'no' (loan granted)
        The set has 9 'yes' rows and 6 'no' rows.
        """
        database = [[1, 0, 0, 0, 'no'],
                    [1, 0, 0, 1, 'no'],
                    [1, 1, 0, 1, 'yes'],
                    [1, 1, 1, 0, 'yes'],
                    [1, 0, 0, 0, 'no'],
                    [2, 0, 0, 0, 'no'],
                    [2, 0, 0, 1, 'no'],
                    [2, 1, 1, 1, 'yes'],
                    [2, 0, 1, 2, 'yes'],
                    # Same features as the row above, so it must share its label.
                    [2, 0, 1, 2, 'yes'],
                    [3, 0, 1, 2, 'yes'],
                    # Table 5.1, row 12: old, no job, own house, good credit -> yes.
                    [3, 0, 1, 1, 'yes'],
                    [3, 1, 0, 1, 'yes'],
                    [3, 1, 0, 2, 'yes'],
                    [3, 0, 0, 0, 'no']]
        return database
    def SplitDatabase(database, axis, value):
        """Select the rows of database whose column `axis` equals `value`.

        axis:  index of the feature column to test.
        value: the feature value a row must have to be kept.

        The matching rows are returned unchanged, as the same row objects
        with the `axis` column still present, in their original order.
        """
        return [row for row in database if row[axis] == value]
    
    def ChooseBest(database, exclude=None):
        """Return the index of the feature with the largest information gain.

        The information gain of feature i is
            H(D) - sum_v |D_v| / |D| * H(D_v)
        where D_v are the rows whose column i equals v, and H is CalEnt.

        database: rows whose last element is the class label.
        exclude:  feature indices that must not be chosen.  When omitted, the
                  module-level list `t` is used if it exists, as before.

        Returns -1 if database is empty or no remaining feature has a
        strictly positive information gain.
        """
        if exclude is None:
            exclude = globals().get('t', ())
        if not database:
            return -1

        baseShanon = CalEnt(database)
        total = float(len(database))
        FeaNum = len(database[0]) - 1  # last column is the class label
        bestInformationGain = 0.0
        BestChoose = -1

        for i in range(FeaNum):
            if i in exclude:
                continue
            # Conditional entropy H(D|A_i): it starts from zero for every
            # feature and is a weighted *sum* of the subset entropies.
            ConShanon = 0.0
            for value in set(row[i] for row in database):
                subdatabase = SplitDatabase(database, i, value)
                ConShanon += len(subdatabase) / total * CalEnt(subdatabase)
            # Compare only after the whole sum over all values is known.
            InformationGain = baseShanon - ConShanon
            if InformationGain > bestInformationGain:
                bestInformationGain = InformationGain
                BestChoose = i
        return BestChoose
                
    def majorityEnt(LabelList):
        """Return the most common label in LabelList (majority vote).

        Ties are resolved by dict iteration order: sorted() with reverse=True
        is stable, so on Python 3.7+ the label seen first wins.

        Raises ValueError if LabelList is empty.
        """
        if not LabelList:
            raise ValueError("majorityEnt() requires at least one label")
        LabelCount = {}
        for vote in LabelList:
            LabelCount[vote] = LabelCount.get(vote, 0) + 1
        # Sort (label, count) pairs by count, largest first.  items() works on
        # both Python 2 and 3, unlike the Python-2-only iteritems().
        sortedLabelCount = sorted(LabelCount.items(),
                                  key=operator.itemgetter(1), reverse=True)
        return sortedLabelCount[0][0]
                
                
    
    def CreatTree(database, label):
        """Recursively build an ID3 decision tree and return it.

        database: non-empty list of rows; the last element of each row is the
                  class label, the other elements are feature values.
        label:    feature names; label[i] names column i of database.

        Returns a class label for a leaf, otherwise a nested dict
        {feature_name: {feature_value: subtree, ...}}.

        The module-level list `t` holds the feature indices already used on
        the path from the root to the current node, and ChooseBest skips
        them.  A feature is pushed before its children are built and popped
        afterwards, so sibling branches may still split on it.
        """
        LabelList = [x[-1] for x in database]
        # Pure node: every row has the same label as the first row.
        if LabelList.count(LabelList[0]) == len(LabelList):
            return LabelList[0]

        # max() returns the first label with the highest count, so ties go to
        # the label that appears first in the data.
        majority = max(LabelList, key=LabelList.count)

        # Only the label column is left: nothing to split on.
        if len(database[0]) == 1:
            return majority

        used = globals().setdefault('t', [])
        Feature = ChooseBest(database)
        if Feature < 0:
            # No usable feature (all used on this path, or none gains any
            # information).  Using -1 below would name the last feature via
            # label[-1] while actually splitting on the class column x[-1].
            return majority

        BestLabel = label[Feature]
        MyTree = {BestLabel: {}}
        used.append(Feature)
        try:
            for value in set(x[Feature] for x in database):
                MyTree[BestLabel][value] = CreatTree(
                    SplitDatabase(database, Feature, value), label)
        finally:
            # Un-mark the feature so sibling branches can use it again.
            used.pop()
        return MyTree
        
        
    
        
            
    
    
    def test():
        """Build the decision tree for the demo loan dataset, print and return it.

        Resets the module-level list `t` (the feature indices used on the
        current path, read by ChooseBest and updated by CreatTree) before
        building the tree.
        """
        global t
        t = []
        mydatabase = CreatDatabase()
        label = ['age', 'work', 'house', 'money']
        tree = CreatTree(mydatabase, label)
        print(tree)
        return tree
        
    # Run the demo only when executed as a script, not on import.
    if __name__ == "__main__":
        test()
    money
    age
    work
    house
    money
    
    {'money': {0: {'age': {1: {'work': {0: 'no', 1: 'yes'}}, 2: 'no', 3: 'no'}}, 1: {'house': {0: {'money': {'no': 'no', 'yes': 'yes'}}, 1: {'money': {'no': 'no', 'yes': 'yes'}}}}, 2: {'money': {'no': 'no', 'yes': 'yes'}}}}

    感觉输出结果最后面的部分有点乱了:出现了以类别标签 'no'/'yes' 作为分支取值的 'money' 节点。

    有空优化一下。

  • 相关阅读:
    垂死挣扎还是涅槃重生 -- Delphi XE5 公布会归来感想
    自考感悟,话谈备忘录模式
    [每日一题] OCP1z0-047 :2013-07-26 alter table set unused之后各种情况处理
    Java实现 蓝桥杯 算法提高 p1001
    Java实现 蓝桥杯 算法提高 拿糖果
    Java实现 蓝桥杯 算法提高 拿糖果
    Java实现 蓝桥杯 算法提高 求arccos值
    Java实现 蓝桥杯 算法提高 求arccos值
    Java实现 蓝桥杯 算法提高 因式分解
    Java实现 蓝桥杯 算法提高 因式分解
  • 原文地址:https://www.cnblogs.com/baochen/p/9046290.html
Copyright © 2011-2022 走看看