zoukankan      html  css  js  c++  java
  • 决策树分类算法原理分析与代码实现

    前言

      本文详细介绍机器学习分类算法中的决策树算法,并全面详解如何构造,表示,保存决策树,以及如何使用决策树进行分类等等问题。

      为了全面的理解学习决策树,本文篇幅较长,请耐心阅读。

    算法原理

      每次依据不同的特征信息对数据集进行划分,划分的最终结果是一棵树。

      该树的每个子树存放一个划分集,而每个叶节点则表示最终分类结果,这样一棵树被称为决策树。

      决策树建好之后,带着目标对象按照一定规则遍历这个决策树就能得到最终的分类结果。

      该算法可以分为两大部分:

        1. 构建决策树部分

        2. 使用决策树分类部分

      其中,第一部分是重点难点。

    决策树构造伪代码

     1 # ==============================================
     2 # 输入:
     3 #        数据集
     4 # 输出:
     5 #        构造好的决策树(也即训练集)
     6 # ==============================================
     7 def 创建决策树:
     8     '创建决策树'
     9     
    10     if (数据集中所有样本分类一致):
    11         创建携带类标签的叶子节点
    12     else:
    13         寻找划分数据集的最好特征
    14         根据最好特征划分数据集
    15         for 每个划分的数据集:
    16             创建决策子树(递归方式)

    核心问题一:依据什么划分数据集

      可采用ID3算法思路:如果以某种特种特征来划分数据集,会导致数据集发生最大程度的改变,那么就使用这种特征值来划分。

      那么又该如何衡量数据集的变化程度呢?

      可采用熵来进行衡量。这个字读作di,第二声,不要读成shang啊,哈哈!

      它用来衡量信息集的无序程度,其计算公式如下:

      

      其中:

      1. x是指分类。要注意决策树的分类是离散的。

      2. P(x)是指任一样本为该分类的概率

      显然,与原数据集相比,熵差最大的划分集就是最优划分集。

      对数据集求熵的代码如下:  

     1 # ==============================================
     2 # 输入:
     3 #        dataSet: 数据集文件名(含路径)
     4 # 输出:
     5 #        shannonEnt: 输入数据集的香农熵
     6 # ==============================================
     7 def calcShannonEnt(dataSet):
     8     '计算香农熵'
     9     
    10     # 数据集个数
    11     numEntries = len(dataSet)
    12     # 标签集合
    13     labelCounts = {}
    14     for featVec in dataSet:     # 行遍历数据集
    15         # 当前标签
    16         currentLabel = featVec[-1]
    17         # 加入标签集合
    18         if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
    19         labelCounts[currentLabel] += 1
    20         
    21     # 计算当前数据集的香农熵并返回
    22     shannonEnt = 0.0
    23     for key in labelCounts:
    24         prob = float(labelCounts[key])/numEntries
    25         shannonEnt -= prob * log(prob,2)
    26         
    27     return shannonEnt

      可用如下函数创建测试数据集并对其求熵:

     1 # ==============================================
     2 # 输入:
     3 #
     4 # 输出:
     5 #        dataSet: 测试数据集列表
     6 # ==============================================
     7 def createDataSet():
     8     '创建测试数据集'
     9     
    10     dataSet = [[1, 1, 'yes'],
    11                [1, 1, 'yes'],
    12                [1, 0, 'no'],
    13                [0, 1, 'no'],
    14                [0, 1, 'no']]
    15     
    16     return dataSet
    17 
    18 def test():
    19     '测试'
    20     
    21     # 创建测试数据集
    22     myDat = createDataSet()
    23     # 求出其熵并打印
    24     print calcShannonEnt(myDat)

      运行结果如下:

      

      如果我们修改测试数据集的某些数据,让其看起来显得混乱点,则得到的熵的值会更大。

      还有其他描述集合无序程度的方法,比如说基尼不纯度等等,这里就不再讨论了。

    核心问题二:如何划分数据集

      这涉及到一些细节上面的问题了,比如:每次划分是否需要剔除某些字段?如何对各种划分所得的熵差进行比较并进行最优划分等等。

      首先是具体划分函数:

     1 # ==============================================
     2 # 输入:
     3 #        dataSet: 训练集文件名(含路径)
     4 #        axis: 用于划分的特征的列数
     5 #        value: 划分值
     6 # 输出:
     7 #        retDataSet: 划分后的子列表
     8 # ==============================================
     9 def splitDataSet(dataSet, axis, value):
    10     '数据集划分'
    11     
    12     # 划分结果
    13     retDataSet = []
    14     for featVec in dataSet:     # 逐行遍历数据集
    15         if featVec[axis] == value:      # 如果目标特征值等于value
    16             # 抽取掉数据集中的目标特征值列
    17             reducedFeatVec = featVec[:axis]
    18             reducedFeatVec.extend(featVec[axis+1:])
    19             # 将抽取后的数据加入到划分结果列表中
    20             retDataSet.append(reducedFeatVec)
    21             
    22     return retDataSet

      然后是选择最优划分函数:

     1 # ===============================================
     2 # 输入:
     3 #        dataSet: 数据集
     4 # 输出:
     5 #        bestFeature: 和原数据集熵差最大划分对应的特征的列号
     6 # ===============================================
     7 def chooseBestFeatureToSplit(dataSet):
     8     '选择最佳划分方案'
     9     
    10     # 特征个数
    11     numFeatures = len(dataSet[0]) - 1
    12     # 原数据集香农熵
    13     baseEntropy = calcShannonEnt(dataSet)
    14     # 暂存最大熵增量
    15     bestInfoGain = 0.0; 
    16     # 和原数据集熵差最大的划分对应的特征的列号
    17     bestFeature = -1
    18     
    19     for i in range(numFeatures):    # 逐列遍历数据集
    20         # 获取该列所有特征值
    21         featList = [example[i] for example in dataSet]
    22         # 将特征值列featList的值唯一化并保存到集合uniqueVals
    23         uniqueVals = set(featList)
    24         
    25         # 新划分法香农熵
    26         newEntropy = 0.0
    27         # 计算该特征划分下所有划分子集的香农熵,并叠加。
    28         for value in uniqueVals:    # 遍历该特征列所有特征值   
    29             subDataSet = splitDataSet(dataSet, i, value)
    30             prob = len(subDataSet)/float(len(dataSet))
    31             newEntropy += prob * calcShannonEnt(subDataSet)
    32         
    33         # 保存所有划分法中,和原数据集熵差最大划分对应的特征的列号。
    34         infoGain = baseEntropy - newEntropy
    35         if (infoGain > bestInfoGain):
    36             bestInfoGain = infoGain
    37             bestFeature = i
    38             
    39     return bestFeature

      用下面代码测试之:

    1 def test():
    2     '测试'
    3     
    4     myDat = createDataSet()
    5     print chooseBestFeatureToSplit(myDat)

      得到的结果是0:

      

      而上面的代码也看到,测试数据集为:

    1     dataSet = [[1, 1, 'yes'],
    2                [1, 1, 'yes'],
    3                [1, 0, 'no'],
    4                [0, 1, 'no'],
    5                [0, 1, 'no']]

      显然,按照第0列特征划分会更加合理,区分度更大。

    核心问题三:如何具体实现树结构

      通过对前面两个问题的分析,划分数据集这一块已经清楚明了了。

      那么如何用这些多层次的划分子集搭建出一个树结构呢?这部分更多涉及到编程技巧,某种程度上来说,就是用Python实现树的问题。

      在Python中,可以用字典来具体实现树:

        字典的键存放节点信息,值存放分支及子树/叶子节点信息。

      比如说对于下面这个树,用Python的字典表述就是:{'no surfacing' : {0, 'no', 1 : {'flippers' : {0 : 'no', 1 : 'yes'}}}}

      如下构建树部分代码。该函数调用后将形成决策树:

     1 # ===============================================
     2 # 输入:
     3 #        classList: 类标签集
     4 # 输出:
     5 #        sortedClassCount[0][0]: 出现次数最多的标签
     6 # ===============================================
     7 def majorityCnt(classList):
     8     '采用多数表决的方式求出classList中出现次数最多的类标签'
     9     
    10     classCount={}
    11     for vote in classList:
    12         if vote not in classCount.keys(): classCount[vote] = 0
    13         classCount[vote] += 1
    14     sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    15     
    16     return sortedClassCount[0][0]
    17 
    18 # ===============================================
    19 # 输入:
    20 #        dataSet: 数据集
    21 #        labels: 划分标签集
    22 # 输出:
    23 #        myTree: 生成的决策树
    24 # ===============================================
    25 def createTree(dataSet,labels):
    26     '创建决策树'
    27     
    28     # 获得类标签列表
    29     classList = [example[-1] for example in dataSet]
    30     
    31     # 递归终止条件一:如果数据集内所有分类一致
    32     if classList.count(classList[0]) == len(classList): 
    33         return classList[0]
    34     
    35     # 递归终止条件二:如果所有特征都划分完毕
    36     if len(dataSet[0]) == 1:
    37         # 将它们都归为一类并返回
    38         return majorityCnt(classList)
    39     
    40     # 选择最佳划分特征
    41     bestFeat = chooseBestFeatureToSplit(dataSet)
    42     # 最佳划分对应的划分标签。注意不是分类标签
    43     bestFeatLabel = labels[bestFeat]
    44     # 构建字典空树
    45     myTree = {bestFeatLabel:{}}
    46     # 从划分标签列表中删掉划分后的元素
    47     del(labels[bestFeat])
    48     # 获取最佳划分对应特征的所有特征值
    49     featValues = [example[bestFeat] for example in dataSet]
    50     # 对特征值列表featValues唯一化,结果存于uniqueVals。
    51     uniqueVals = set(featValues)
    52     
    53     for value in uniqueVals:    # 逐行遍历特征值集合
    54         # 保存所有划分标签信息并将其伙同划分后的数据集传递进下一次递归
    55         subLabels = labels[:]
    56         myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
    57         
    58     return myTree 

      如下代码可用于测试函数是否正常执行:

     1 # ==============================================
     2 # 输入:
     3 #
     4 # 输出:
     5 #        用于测试的数据集和划分标签集
     6 # ==============================================
     7 def createDataSet():
     8     '创建测试数据集'
     9     
    10     dataSet = [[1, 1, 'yes'],
    11                [1, 1, 'yes'],
    12                [1, 0, 'no'],
    13                [0, 1, 'no'],
    14                [0, 1, 'no']]
    15     labels = ['no surfacing', 'flippers']
    16     
    17     return dataSet, labels
    18 
    19 def test():
    20     '测试'
    21     
    22     myDat, labels = createDataSet()
    23     myTree = createTree(myDat, labels)
    24     print myTree

      运行结果:

      

    使用Matplotlib绘制树形图

      当决策树构建好了以后,自然需要用一种方式来显示给开发人员。仅仅是一个字典表达式很难让人满意。

      因此,可采用Matplotlib来绘制树形图。

      这涉及到两方面的知识:

        1. 遍历树获取树的高度,叶子数等信息。

        2. Matplotlib绘制图像的一些API

      对于第一部分的任务,可以用递归的方式遍历字典树,从而获得树的相关信息。

      下面给出求树的叶子树及树高的函数:

     1 # ===============================================
     2 # 输入:
     3 #        myTree: 决策树
     4 # 输出:
     5 #        numLeafs: 决策树的叶子数
     6 # ===============================================
     7 def getNumLeafs(myTree):
     8     '计算决策树的叶子数'
     9     
    10     # 叶子数
    11     numLeafs = 0
    12     # 节点信息
    13     firstStr = myTree.keys()[0]
    14     # 分支信息
    15     secondDict = myTree[firstStr]
    16     
    17     for key in secondDict.keys():   # 遍历所有分支
    18         # 子树分支则递归计算
    19         if type(secondDict[key]).__name__=='dict':
    20             numLeafs += getNumLeafs(secondDict[key])
    21         # 叶子分支则叶子数+1
    22         else:   numLeafs +=1
    23         
    24     return numLeafs
    25 
    26 # ===============================================
    27 # 输入:
    28 #        myTree: 决策树
    29 # 输出:
    30 #        maxDepth: 决策树的深度
    31 # ===============================================
    32 def getTreeDepth(myTree):
    33     '计算决策树的深度'
    34     
    35     # 最大深度
    36     maxDepth = 0
    37     # 节点信息
    38     firstStr = myTree.keys()[0]
    39     # 分支信息
    40     secondDict = myTree[firstStr]
    41     
    42     for key in secondDict.keys():   # 遍历所有分支
    43         # 子树分支则递归计算
    44         if type(secondDict[key]).__name__=='dict':
    45             thisDepth = 1 + getTreeDepth(secondDict[key])
    46         # 叶子分支则叶子数+1
    47         else:   thisDepth = 1
    48         
    49         # 更新最大深度
    50         if thisDepth > maxDepth: maxDepth = thisDepth
    51         
    52     return maxDepth

      对于第二部分的任务 - 画树,其实本质就是画点和画线,下面给出基本的线画法:

     1 import matplotlib.pyplot as plt
     2 
     3 decisionNode = dict(boxstyle="sawtooth", fc="0.8")
     4 leafNode = dict(boxstyle="round4", fc="0.8")
     5 arrow_args = dict(arrowstyle="<-")
     6     
     7 # ==================================================
     8 # 输入:
     9 #        nodeTxt:    终端节点显示内容
    10 #        centerPt:    终端节点坐标
    11 #        parentPt:    起始节点坐标
    12 #        nodeType:    终端节点样式
    13 # 输出:
    14 #        在图形界面中显示输入参数指定样式的线段(终端带节点)
    15 # ==================================================
    16 def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    17     '画线(末端带一个点)'
    18         
    19     createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction', va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )
    20         
    21 def createPlot():
    22     '绘制有向线段(末端带一个节点)并显示'
    23         
    24     # 新建一个图对象并清空
    25     fig = plt.figure(1, facecolor='white')
    26     fig.clf()
    27     # 设置1行1列个图区域,并选择其中的第1个区域展示数据。
    28     createPlot.ax1 = plt.subplot(111, frameon=False)
    29         
    30     # 画线(末端带一个节点)
    31     plotNode('decisionNode', (0.5, 0.1), (0.1, 0.5), decisionNode)
    32     plotNode('leafNode', (0.8, 0.1), (0.3, 0.8), leafNode)
    33         
    34     # 显示绘制结果
    35     plt.show()

      调用 createPlot 函数即可显示绘制结果:

      

      下面,将这两部分内容整合起来,写出最终绘制树的代码:

      1 import matplotlib.pyplot as plt
      2 
      3 decisionNode = dict(boxstyle="sawtooth", fc="0.8")
      4 leafNode = dict(boxstyle="round4", fc="0.8")
      5 arrow_args = dict(arrowstyle="<-")
      6     
      7 # ==================================================
      8 # 输入:
      9 #        nodeTxt:     终端节点显示内容
     10 #        centerPt:    终端节点坐标
     11 #        parentPt:    起始节点坐标
     12 #        nodeType:    终端节点样式
     13 # 输出:
     14 #        在图形界面中显示输入参数指定样式的线段(终端带节点)
     15 # ==================================================
     16 def plotNode(nodeTxt, centerPt, parentPt, nodeType):
     17     '画线(末端带一个点)'
     18         
     19     createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction', va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )
     20 
     21 # =================================================================
     22 # 输入:
     23 #        cntrPt:      终端节点坐标
     24 #        parentPt:    起始节点坐标
     25 #        txtString:   待显示文本内容
     26 # 输出:
     27 #        在图形界面指定位置(cntrPt和parentPt中间)显示文本内容(txtString)
     28 # =================================================================
     29 def plotMidText(cntrPt, parentPt, txtString):
     30     '在指定位置添加文本'
     31     
     32     # 中间位置坐标
     33     xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
     34     yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
     35     
     36     createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
     37 
     38 # ===================================
     39 # 输入:
     40 #        myTree:    决策树
     41 #        parentPt:  根节点坐标
     42 #        nodeTxt:   根节点坐标信息
     43 # 输出:
     44 #        在图形界面绘制决策树
     45 # ===================================
     46 def plotTree(myTree, parentPt, nodeTxt):
     47     '绘制决策树'
     48     
     49     # 当前树的叶子数
     50     numLeafs = getNumLeafs(myTree)
     51     # 当前树的节点信息
     52     firstStr = myTree.keys()[0]
     53     # 定位第一棵子树的位置(这是蛋疼的一部分)
     54     cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
     55     
     56     # 绘制当前节点到子树节点(含子树节点)的信息
     57     plotMidText(cntrPt, parentPt, nodeTxt)
     58     plotNode(firstStr, cntrPt, parentPt, decisionNode)
     59     
     60     # 获取子树信息
     61     secondDict = myTree[firstStr]
     62     # 开始绘制子树,纵坐标-1。        
     63     plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
     64       
     65     for key in secondDict.keys():   # 遍历所有分支
     66         # 子树分支则递归
     67         if type(secondDict[key]).__name__=='dict':
     68             plotTree(secondDict[key],cntrPt,str(key))
     69         # 叶子分支则直接绘制
     70         else:
     71             plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
     72             plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
     73             plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
     74      
     75     # 子树绘制完毕,纵坐标+1。
     76     plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
     77 
     78 # ==============================
     79 # 输入:
     80 #        myTree:    决策树
     81 # 输出:
     82 #        在图形界面显示决策树
     83 # ==============================
     84 def createPlot(inTree):
     85     '显示决策树'
     86     
     87     # 创建新的图像并清空 - 无横纵坐标
     88     fig = plt.figure(1, facecolor='white')
     89     fig.clf()
     90     axprops = dict(xticks=[], yticks=[])
     91     createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
     92     
     93     # 树的总宽度 高度
     94     plotTree.totalW = float(getNumLeafs(inTree))
     95     plotTree.totalD = float(getTreeDepth(inTree))
     96     
     97     # 当前绘制节点的坐标
     98     plotTree.xOff = -0.5/plotTree.totalW; 
     99     plotTree.yOff = 1.0;
    100     
    101     # 绘制决策树
    102     plotTree(inTree, (0.5,1.0), '')
    103     
    104     plt.show()
    105         
    106 def test():
    107     '测试'
    108     
    109     myDat, labels = createDataSet()
    110     myTree = createTree(myDat, labels)
    111     createPlot(myTree)

      运行结果如下图:

      

    关于决策树的存储

      这部分也很重要。

      生成一个决策树比较耗时间,谁也不想每次启动程序都重新进行机器学习吧。能否将学习结果 - 决策树保存到硬盘中去呢?

      答案是肯定的,以下两个函数分别实现了决策树的存储与打开:

     1 # ======================
     2 # 输入:
     3 #        myTree:    决策树
     4 # 输出:
     5 #        决策树文件
     6 # ======================
     7 def storeTree(inputTree,filename):
     8     '保存决策树'
     9     
    10     import pickle
    11     fw = open(filename,'w')
    12     pickle.dump(inputTree,fw)
    13     fw.close()
    14     
    15 # ========================
    16 # 输入:
    17 #        filename:    决策树文件名
    18 # 输出:
    19 #        pickle.load(fr):    决策树
    20 # ========================    
    21 def grabTree(filename):
    22     '打开决策树'
    23     
    24     import pickle
    25     fr = open(filename)
    26     return pickle.load(fr)

    使用决策树进行分类

      终于到了这一步,也是最终一步了。

      拿到需要分类的数据后,遍历决策树直至叶子节点,即可得到分类结果,是不是很简单呢?

      下面给出遍历及测试代码:

     1 # ========================
     2 # 输入:
     3 #        inputTree:    决策树文件名
     4 #        featLabels:    分类标签集
     5 #        testVec:        待分类对象
     6 # 输出:
     7 #        classLabel:    分类结果
     8 # ========================     
     9 def classify(inputTree,featLabels,testVec):
    10     '使用决策树分类'
    11     
    12     # 当前分类标签
    13     firstStr = inputTree.keys()[0]
    14     secondDict = inputTree[firstStr]
    15     # 找到当前分类标签在分类标签集中的下标
    16     featIndex = featLabels.index(firstStr)
    17     # 获取待分类对象中当前分类的特征值
    18     key = testVec[featIndex]
    19     
    20     # 遍历
    21     valueOfFeat = secondDict[key]
    22     
    23     # 子树分支则递归
    24     if isinstance(valueOfFeat, dict): 
    25         classLabel = classify(valueOfFeat, featLabels, testVec)
    26     # 叶子分支则返回结果
    27     else: classLabel = valueOfFeat
    28     
    29     return classLabel
    30 
    31 def test():
    32     '测试'
    33     
    34     myDat, labels = createDataSet()
    35     myTree = createTree(myDat, labels)
    36     # 再创建一次数据的原因是创建决策树函数会将labels值改动
    37     myDat, labels = createDataSet()
    38     print classify(myTree, labels, [1,1])

      运行结果如下:

      

      OK,一个完整的决策树使用例子就实现了。

    小结

      1. 本文演示的是最经典ID3决策树,但它在实际应用中存在过度匹配的问题。在以后的文章中会学习如何对决策树进行裁剪。

      2. 本文采用的ID3决策树算法只能用于标称型数据。对于数值型数据,需要使用Cart决策树构造算法。这个算法将在以后进行深入学习。

  • 相关阅读:
    css3阴影效果
    应该了解的9种CSS技巧
    position
    MyEclipse设置Java代码注释模板
    Struts2 常用的常量配置
    CSS 中文字体对应英文和Unicode编码
    MyEclipse使用前优化与配置
    MyEclipse 快捷键收集
    Ajax 调用WebServices之一 基本应用
    C#控制台显示进度条
  • 原文地址:https://www.cnblogs.com/scut-fm/p/4180630.html
Copyright © 2011-2022 走看看