zoukankan      html  css  js  c++  java
  • 决策树代码《机器学习实战》

    22:45:17 2017-08-09

    KNN算法简单有效,可以解决很多分类问题。但是无法给出数据的含义,就是一顿计算向量距离,然后分类。

    决策树就可以解决这个问题,分类之后能够知道是问什么被划分到一个类。用图形画出来就效果更好了,这次没有学哪个画图的,下次。

    这里只涉及信息熵的计算,最佳分类特征的提取,决策树的构建。剪枝没有学,这里没有。

      1 # -*- oding: itf-8 -*-
      2 
      3 '''
      4 function: 《机器学习实战》决策树的代码,画图的部分没有写;
      5 note: 贴出来以后用方便一点~
      6 date: 2017.8.9
      7 '''
      8 
      9 from numpy import *
     10 from math import log
     11 import operator
     12 
     13 #计算香浓信息熵
     14 def calcuEntropy(dataSet):
     15     numOfEntries = len(dataSet)
     16     featVec = {}
     17     for data in dataSet:
     18         currentLabel = data[-1]
     19         if currentLabel not in featVec.keys():
     20             featVec[currentLabel] = 1
     21         else:
     22             featVec[currentLabel] += 1
     23     shannonEntropy = 0.0
     24     for feat in featVec.keys():
     25         prob = float(featVec[feat]) / numOfEntries
     26         shannonEntropy += -prob*log(prob, 2) 
     27     return shannonEntropy
     28 
     29 #产生数据集
     30 def loadDataSet():
     31     dataSet = [[1,1,'yes'],
     32                 [1,0,'no'],
     33                 [0,1,'no'],
     34                 [0,1,'no']]
     35     labels = ['no surfacing', 'flippers']
     36     return dataSet, labels
     37 
     38 '''
     39 function: split the dataset
     40 return: 基于划分特征划分之后我们想要的那部分集合
     41 parameters: dataSet: 数据集,axis: 要划分的特征, value:要返回的集合的axis特征值
     42 '''
     43 def splitDataSet(dataSet, axis, value):
     44     retDataSet = [] #防止原始的数据集被修改
     45     for featVec in dataSet:
     46         if featVec[axis] == value: #我们想要的数值存起来,一会返回
     47             reducedFeatVec = featVec[:axis]
     48             reducedFeatVec.extend(featVec[axis+1:])
     49             retDataSet.append(reducedFeatVec)
     50     return retDataSet
     51 
     52 '''
     53 function: 找出数据集中最佳的划分特征
     54 '''
     55 def chooseBestClassifyFeat(dataSet):
     56     numOfFeatures = len(dataSet[0]) - 1
     57     bestFeature = -1  #初始化最佳的划分特征
     58     baseInfoGain = 0.0 #信息增益
     59     baseEntropy = calcuEntropy(dataSet)
     60     for i in range(numOfFeatures):
     61         # if numOfFeatures == 1: #错了,只有一个特征不是只有一个类别
     62         #     print('only one feature')
     63         #     print(dataSet[0][0])
     64         #     return dataSet[0][0] #只有一个特征直接返回该特征
     65         featList = [example[i] for example in dataSet] #或者第i个特征所有的取值
     66         unicVals = set(featList) #不重复的第i个特征取值
     67         newEntropy = 0.0
     68         for value in unicVals:
     69             subDataSet = splitDataSet(dataSet, i, value)
     70 
     71             #计算划分之后各个子数据集的信息熵,然后累加就是这个划分的信息熵
     72             currentEntropy = calcuEntropy(subDataSet) 
     73             prob = float(len(subDataSet)) / len(dataSet)
     74             newEntropy += prob * currentEntropy
     75         newInfoGain = baseEntropy - newEntropy
     76         if newInfoGain > baseInfoGain:
     77             bestFeature = i
     78             baseInfoGain = newInfoGain
     79     return bestFeature 
     80 
     81 '''
     82 function: 多数表决,当分类器用完所有属性,叶节点还是类别不统一的时候调用这个函数
     83 arg: labelList 类别标签列表
     84 '''
     85 def majorityCount(labelList):
     86     classCount = {}
     87     for label in labelList:
     88         if label not in classCount.keys():
     89             classCount[label] = 0
     90         classCount[label] += 1
     91     sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1),reverse = True)
     92     print(sortedClassCount)
     93     return sortedClassCount[0][0]
     94 
     95 
     96 '''
     97 function: 递归的建造决策树
     98 arg: dataset: 数据集 labels: 代表特征的标签,起始算法不需要,比如fippers代表第一个特征的意义
     99 '''
    100 def createTree(dataSet, labels):
    101     classList = [example[-1] for example in dataSet] #得到所有的类别
    102     if classList.count(classList[0]) == len(classList): #只有一种类别,直接返回
    103         return classList[0]
    104     if len(dataSet[0]) == 1: #特征属性用完了但是还没有完全分开,多数表决
    105         return majorityCount(classList)
    106     bestFeat = chooseBestClassifyFeat(dataSet)
    107     print('bestFeat = ' + str(bestFeat))
    108     bestFeatLabel = labels[bestFeat]
    109     del(labels[bestFeat]) #删除这次使用的特征
    110     featValues = [example[bestFeat] for example in dataSet]
    111     myTree = {bestFeatLabel: {}}
    112     unicVals = set(featValues)
    113     for value in unicVals:
    114         labelCopy = labels[:]
    115         subDataSet = splitDataSet(dataSet, bestFeat, value)
    116         myTree[bestFeatLabel][value] = createTree(subDataSet, labelCopy)
    117     return myTree
    118 
    119 '''
    120 function: 用决策树进行分类
    121 arg: inputTree: 训练好的决策树,featLabels: 特征标签,testVec: 待分类的向量
    122 '''
    123 def classify(inputTree, featLabel, testVec):
    124     firstStr = list(inputTree.keys())[0] #python3 dict,.keys()不支持索引,必须转换一下
    125     secondDict = inputTree[firstStr] #second tree
    126     featIndex = featLabel.index(firstStr) #可利用index函数找到这个特征标签对饮过的特征位置
    127     for key in secondDict.keys():
    128         if testVec[featIndex] == key:
    129             if type(secondDict[key]).__name__ == 'dict': #说明下面不是叶子节点,继续分类
    130                 classLabel = classify(secondDict[key], featLabel, testVec)
    131             else:
    132                 classLabel = secondDict[key] #到达叶子节点,直接返回类别标签
    133     return classLabel
    134 
    135 '''
    136 function: 使用pickle模块持久化存储决策树
    137 note:
    138 '''
    139 def storeTree(inputTree, filename):
    140     import pickle
    141     fw = open(filename, 'wb')
    142     pickle.dump(inputTree, fw)
    143     fw.close()
    144 
    145 '''
    146 function: 从本地文件中读取决策树
    147 '''
    148 def grabTree(filename):
    149     import pickle
    150     fr = open(filename,'rb')
    151     return pickle.load(fr)
    152 
    153 #测试信息熵的计算
    154 dataSet, labels = loadDataSet()
    155 shannon = calcuEntropy(dataSet)
    156 print(shannon)
    157 
    158 #测试数据集分割
    159 print(dataSet)
    160 retDataSet = splitDataSet(dataSet, 1, 1)
    161 print(retDataSet)
    162 retDataSet = splitDataSet(dataSet, 1, 0)
    163 print(retDataSet)
    164 
    165 #寻找最佳的划分特征
    166 bestFeature = chooseBestClassifyFeat(dataSet)
    167 print(bestFeature)
    168 
    169 #测试多数表决
    170 out = majorityCount([1,1,2,2,2,1,2,2])
    171 print(out)
    172 
    173 #创建决策大叔
    174 myTree = createTree(dataSet, labels)
    175 print(myTree)
    176 
    177 #测试分类器
    178 dataSet, labels = loadDataSet()
    179 classLabel = classify(myTree, labels, [0,1])
    180 print(classLabel)
    181 classLabel = classify(myTree, labels, [1,1])
    182 print(classLabel)
    183 
    184 #持久化存储决策树
    185 storeTree(myTree, 'classifierStorage.txt')
    186 outTree = grabTree('classifierStorage.txt')
    187 print(outTree)
  • 相关阅读:
    [文字雲產生器] Tagxedo 把文字串成雲、變成畫,印在 T-Shirt、馬克杯、詩袋….
    python学习(六)
    根据URL地址获取域名
    python学习(五)
    Linux下查看Mysql数据库端口的方法
    python学习(四)
    python学习(三)
    python学习(二)
    Java String删除字符串中间的某部分
    Spring的一个入门例子
  • 原文地址:https://www.cnblogs.com/robin2ML/p/7331008.html
Copyright © 2011-2022 走看看