  • "Watermelon Book" Chapter 4, Decision Trees, Part 3

    ▶ Another rework of the decision tree. While adding pruning I found the original data structure to be awful, so I rebuilt the tree's node structure.

    ● Code. The main addition is the post-pruning function; pre-pruning can be implemented on top of the existing createTree, so it is not written out separately here (a sketch of the idea is given after the listing).

      import numpy as np
      import operator
      import warnings

      warnings.filterwarnings("ignore")                            # silence the divide/invalid RuntimeWarnings from log2(0) inside plogp
      dataSize = 10000
      trainRatio = 0.3

      def dataSplit(x, y, part):                                   # split (x, y) into a head of length part and the remainder
          return x[:part], y[:part], x[part:], y[part:]

      def createData(dim, option, kind, count = dataSize):         # build a dataset: feature count, values per feature, class count, sample count
          np.random.seed(103)
          X = np.random.randint(option, size = [count, dim])
          if kind == 2:
              Y = ((3 - 2 * dim) * X[:,0] + 2 * np.sum(X[:,1:], 1) > 0.5).astype(int)  # kept separate for easy inspection; could be merged into the else branch
          else:
              randomVector = np.random.rand(dim)
              randomVector /= np.sum(randomVector)
              Y = (np.sum(X * randomVector, 1) * kind / option).astype(int)            # the classes come out rather unbalanced
          label = [ chr(i + 65) for i in range(dim) ]              # feature names: 'A', 'B', 'C', ...

          print("dim = %d, option = %d, kind = %d, dataSize = %d" % (dim, option, kind, count))
          kindCount = np.zeros(kind, dtype = int)                  # per-class share of the samples
          for i in range(count):
              kindCount[Y[i]] += 1
          for i in range(kind):
              print("kind %d -> %4f" % (i, kindCount[i] / count))
          return X, Y, label

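      # note on the kind == 2 rule: Y = 1 iff (3 - 2*dim)*X0 + 2*(X1 + ... + X_{dim-1}) > 0.5;
      # for dim = 2 this reduces to Y = (-X0 + 2*X1 > 0.5), which the example tree below reproduces
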
      def plogp(x):                                                # x * log2(x) for the entropy terms, with the nan from 0 * log2(0) mapped to 0
          return np.nan_to_num( x * np.log2(x) )

      def calculateGain(table, alpha = 0):                         # score a split from its value-by-class count table
          sumC = np.sum(table, 0)                                  # column sums: samples per class
          sumR = np.sum(table, 1)                                  # row sums: samples per feature value
          sumA = np.sum(sumC)                                      # total sample count
          if alpha == 0:                                           # information gain, simplified form (derivation below)
              return ( np.sum(plogp(table)) + plogp(sumA) - np.sum(plogp(sumC)) - np.sum(plogp(sumR)) ) / sumA
          elif alpha == 1:                                         # gain ratio: information gain divided by the split's intrinsic value
              return ( np.sum(plogp(table)) + plogp(sumA) - np.sum(plogp(sumC)) - np.sum(plogp(sumR)) ) / ( plogp(sumA) - np.sum(plogp(sumR)) )
          else:                                                    # reciprocal of the weighted Gini index (smaller Gini is better)
              return sumA / np.sum(sumR * (1 - np.sum(table * table, 1) / (sumR * sumR)))

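      # Derivation of the simplified gain: with plogp(n) = n * log2(n) and N = sumA,
      #     H(D)                 = ( plogp(N) - sum_k plogp(sumC_k) ) / N
      #     sum_v |Dv|/N * H(Dv) = ( sum_v plogp(sumR_v) - sum_vk plogp(table_vk) ) / N
      # so Gain = H(D) - sum_v |Dv|/N * H(Dv)
      #         = ( sum(plogp(table)) + plogp(N) - sum(plogp(sumC)) - sum(plogp(sumR)) ) / N,
      # which is exactly the alpha == 0 expression computed above.
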
      def chooseFeature(dataX, dataY, label):                      # pick the feature with the largest gain
          count, dim = np.shape(dataX)
          maxGain = 0
          maxi = 0                                                 # fall back to the first feature if no split has positive gain
          kindTable = list(set(dataY))                             # table of classes
          for i in range(dim):
              valueList = list(set([ dataX[j][i] for j in range(count) ]))   # table of values taken by feature i
              table = np.zeros([len(valueList), len(kindTable)])             # count table used for the entropy computation
              for j in range(count):
                  table[valueList.index(dataX[j][i]), kindTable.index(dataY[j])] += 1
              gain = calculateGain(table)                          # keep the feature with the largest gain
              if gain > maxGain:
                  maxGain = gain
                  maxi = i
          valueList = list(set([ dataX[j][maxi] for j in range(count) ]))    # value table of the chosen feature
          return (maxi, valueList)

      def vote(kindList):                                          # majority vote over a list of class labels
          kindCount = {}
          for i in kindList:
              if i not in kindCount:
                  kindCount[i] = 0
              kindCount[i] += 1
          output = sorted(kindCount.items(), key = operator.itemgetter(1), reverse = True)
          return output[0][0]

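      # note: vote(kindList) is essentially collections.Counter(kindList).most_common(1)[0][0]
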
      def createTree(dataX, dataY, label):                         # build a decision tree from the current data
          # the dataX == [] case cannot occur: the value/valueList bookkeeping keeps empty branches from being created
          if len(dataX[0]) == 0:                                   # no features left: vote on the remaining class labels
              return {'default': vote(dataY)}
          if len(set(dataY)) == 1:                                 # all samples in one class: return that class
              return {'default': dataY[0]}

          bestFeature, valueList = chooseFeature(dataX, dataY, label)        # index and value table of the best feature
          myTree = {'label': label[bestFeature], 'default': vote(dataY), 'valueList': valueList,
                    'child': [ [] for i in range(len(valueList)) ]}          # node: chosen feature, default class, value table, children

          subLabel = label[:bestFeature] + label[bestFeature + 1:]           # drop the chosen feature from the feature list
          childX = [ [] for i in range(len(valueList)) ]
          childY = [ [] for i in range(len(valueList)) ]
          for i in range(len(dataX)):
              line = dataX[i]
              index = valueList.index(line[bestFeature])
              childX[index].append(np.concatenate([line[:bestFeature], line[bestFeature + 1:]]))
              childY[index].append(dataY[i])
          for i in range(len(valueList)):                          # build the subtrees
              myTree['child'][i] = createTree(childX[i], childY[i], subLabel)
          return myTree                                            # return the current node

      # Pre-pruning could be implemented by splitting createTree in two: the first function, after
      # choosing the best feature for the current node, computes the error rate of keeping the node
      # as a leaf and returns immediately; that value is compared with the parent's error rate to
      # decide whether the node should be added at all. If so, a second function (the part from
      # myTree = {...} onward) actually adds it; otherwise the parent returns as a leaf.
      # A sketch of this idea is given after the listing.

      def cutTree(nowTree, dataX, dataY, labelName):               # post-pruning on a built tree; returns the error count on (dataX, dataY)
          cutErrorCount = np.sum((np.array(dataY) != nowTree['default']).astype(int))   # errors if this node is pruned into a leaf
          if 'valueList' not in nowTree:                           # already a leaf: return its error count
              return cutErrorCount

          count = len(nowTree['valueList'])                        # route the validation samples by the value of the node's feature
          childX = [ [] for j in range(count + 1) ]                # the extra slot holds samples whose value is absent from valueList
          childY = [ [] for j in range(count + 1) ]
          col = labelName.index(nowTree['label'])                  # which column of dataX holds the node's feature
          for i in range(len(dataX)):
              if dataX[i][col] not in nowTree['valueList']:        # value never seen when the tree was built
                  row = count
              else:
                  row = nowTree['valueList'].index(dataX[i][col])  # the sample's value decides which child it goes to
              childX[row].append(dataX[i])
              childY[row].append(dataY[i])

          childErrorCount = np.zeros(count + 1)
          for i in range(count):                                   # prune each branch recursively
              if len(childX[i]) != 0:
                  childErrorCount[i] = cutTree(nowTree['child'][i], childX[i], childY[i], labelName)
          childErrorCount[count] = np.sum((np.array(childY[count]) != nowTree['default']).astype(int))   # unseen values fall back to the default class

          notCutErrorCount = np.sum(childErrorCount)               # total errors over all branches if the subtree is kept

          if cutErrorCount < notCutErrorCount:                     # pruning gives fewer errors: turn this node into a leaf
              nowTree.pop('label')
              nowTree.pop('valueList')
              nowTree.pop('child')
              return cutErrorCount
          else:
              return notCutErrorCount

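      # note: with a strict <, a tie keeps the subtree; pruning on ties as well (<=) would prefer
      # the simpler tree, in the spirit of Occam's razor
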
      def test(dim, option, kind):
          allX, allY, labelName = createData(dim, option, kind)
          trainX, trainY, testX, testY = dataSplit(allX, allY, int(dataSize * trainRatio))     # split off the training set
          cutX, cutY, testX, testY = dataSplit(testX, testY, int(dataSize * trainRatio / 3))   # split off the validation set used for pruning
          outputTree = createTree(trainX, trainY, labelName)
          cutTree(outputTree, cutX, cutY, labelName)

          myResult = []                                            # classification results

          for line in testX:
              tempTree = outputTree                                # walk the sample down the decision tree
              while True:
                  if 'label' not in tempTree:
                      myResult.append(tempTree['default'])
                      break
                  value = line[labelName.index(tempTree['label'])] # the sample's value of the node's feature
                  if value not in tempTree['valueList']:
                      myResult.append(tempTree['default'])
                      break
                  tempTree = tempTree['child'][tempTree['valueList'].index(value)]   # descend one level

          errorRatio = np.sum((np.array(myResult) != testY).astype(int)) / len(testY)          # classification error rate on the test set
          print("errorRatio = %4f" % errorRatio)

      if __name__ == '__main__':
          test(1, 2, 2)
          test(1, 3, 2)
          test(2, 2, 2)
          test(2, 3, 2)
          test(3, 3, 2)
          test(4, 3, 2)
          test(5, 4, 2)
          test(6, 5, 2)
          test(3, 3, 3)
          test(4, 4, 3)
          test(5, 4, 4)
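
    ● A sketch of the pre-pruning idea described in the comments after createTree. This is not part of the original program: the name createTreePrePruned and the extra cutX/cutY parameters (the validation set) are my own, and the sketch reuses vote and chooseFeature and the same node layout as the listing above. A split is only accepted when a one-level split lowers the validation error below that of keeping the node as a leaf.

      def createTreePrePruned(dataX, dataY, cutX, cutY, label):
          default = vote(dataY)
          if len(dataX[0]) == 0 or len(set(dataY)) == 1:           # same stopping rules as createTree
              return {'default': default}
          stopError = np.sum((np.array(cutY) != default).astype(int))        # validation errors if we stop here

          bestFeature, valueList = chooseFeature(dataX, dataY, label)
          n = len(valueList)
          childX = [ [] for i in range(n) ]
          childY = [ [] for i in range(n) ]
          cutChildX = [ [] for i in range(n) ]
          cutChildY = [ [] for i in range(n) ]
          splitError = 0
          for i in range(len(dataX)):                              # route training samples exactly as createTree does
              idx = valueList.index(dataX[i][bestFeature])
              childX[idx].append(np.concatenate([dataX[i][:bestFeature], dataX[i][bestFeature + 1:]]))
              childY[idx].append(dataY[i])
          for i in range(len(cutX)):                               # route validation samples the same way
              v = cutX[i][bestFeature]
              if v not in valueList:                               # unseen value: judged by the node default
                  splitError += int(cutY[i] != default)
                  continue
              idx = valueList.index(v)
              cutChildX[idx].append(np.concatenate([cutX[i][:bestFeature], cutX[i][bestFeature + 1:]]))
              cutChildY[idx].append(cutY[i])
          for i in range(n):                                       # errors if each branch answers with its majority class
              splitError += np.sum((np.array(cutChildY[i]) != vote(childY[i])).astype(int))

          if splitError >= stopError:                              # splitting does not improve validation error: stop
              return {'default': default}
          subLabel = label[:bestFeature] + label[bestFeature + 1:]
          # branches that receive no validation samples collapse to leaves in the recursion
          # (their stopError and splitError are both 0), a deliberate simplification of this sketch
          return {'label': label[bestFeature], 'default': default, 'valueList': valueList,
                  'child': [ createTreePrePruned(childX[i], childY[i], cutChildX[i], cutChildY[i], subLabel)
                             for i in range(n) ]}

      # usage (replacing createTree + cutTree in test):
      #     outputTree = createTreePrePruned(trainX, trainY, cutX, cutY, labelName)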

    ● The decision tree generated by test(2, 3, 2) with dataSize = 10000

    {'label': 'B', 
     'default': 1, 
     'valueList': [0, 1, 2], 
     'child': [{'default': 0}, 
               {'label': 'A', 
                'default': 1, 
                'valueList': [0, 1, 2], 
                'child': [{'default': 1}, 
                          {'default': 1}, 
                          {'default': 0}
                         ]
               }, 
               {'default': 1}
              ]          
    }
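
    Reading the structure: a sample is routed via 'label' and 'valueList' until it reaches a node without 'label', whose 'default' is the answer. A minimal lookup helper with the same logic as the loop inside test (the name classify is mine, it is not in the original code):

      def classify(tree, line, labelName):                         # walk one sample down the tree
          while 'label' in tree:
              value = line[labelName.index(tree['label'])]
              if value not in tree['valueList']:                   # unseen value: fall back to the node default
                  return tree['default']
              tree = tree['child'][tree['valueList'].index(value)]
          return tree['default']

    For the tree above, classify(tree, [2, 1], ['A', 'B']) follows B = 1, then A = 2, and returns class 0, which agrees with the generating rule Y = (-A + 2B > 0.5) for dim = 2.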

    ● Output

    dim = 1, option = 2, kind = 2, dataSize = 10000
    kind 0 -> 0.498600
    kind 1 -> 0.501400
    errorRatio = 0.000000
    dim = 1, option = 3, kind = 2, dataSize = 10000
    kind 0 -> 0.339500
    kind 1 -> 0.660500
    errorRatio = 0.000000
    dim = 2, option = 2, kind = 2, dataSize = 10000
    kind 0 -> 0.497400
    kind 1 -> 0.502600
    errorRatio = 0.000000
    dim = 2, option = 3, kind = 2, dataSize = 10000
    kind 0 -> 0.446700
    kind 1 -> 0.553300
    errorRatio = 0.000000
    dim = 3, option = 3, kind = 2, dataSize = 10000
    kind 0 -> 0.444400
    kind 1 -> 0.555600
    errorRatio = 0.000000
    dim = 4, option = 3, kind = 2, dataSize = 10000
    kind 0 -> 0.452500
    kind 1 -> 0.547500
    errorRatio = 0.000000
    dim = 5, option = 4, kind = 2, dataSize = 10000
    kind 0 -> 0.468000
    kind 1 -> 0.532000
    errorRatio = 0.009000
    dim = 6, option = 5, kind = 2, dataSize = 10000
    kind 0 -> 0.464400
    kind 1 -> 0.535600
    errorRatio = 0.067286
    dim = 3, option = 3, kind = 3, dataSize = 10000
    kind 0 -> 0.485300
    kind 1 -> 0.476600
    kind 2 -> 0.038100
    errorRatio = 0.000000
    dim = 4, option = 4, kind = 3, dataSize = 10000
    kind 0 -> 0.405300
    kind 1 -> 0.568100
    kind 2 -> 0.026600
    errorRatio = 0.000000
    dim = 5, option = 4, kind = 4, dataSize = 10000
    kind 0 -> 0.207800
    kind 1 -> 0.584400
    kind 2 -> 0.207200
    kind 3 -> 0.000600
    errorRatio = 0.008571
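
    The error rates track the coverage of the feature grid: errorRatio stays at zero while the 3000 training samples cover essentially all option^dim value combinations, and rises once some cells go unseen and test samples must fall back to node defaults, slightly for test(5, 4, 2) (4^5 = 1024 cells) and more clearly for test(6, 5, 2) (5^6 = 15625 cells, far more than the 3000 training samples).
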
  • Original post: https://www.cnblogs.com/cuancuancuanhao/p/11182579.html