  • Decision Tree Series 3: Gini Index, Pruning and

    #####################################################################################################

    # -*- coding: utf-8 -*-

    """
    Created on Tue Aug 14 17:36:57 2018

    @author: weixw
    """
    import numpy as np

    Tree structure definition. A binary tree is used: the left subtree holds samples for which the condition is true, the right subtree holds samples for which it is false.

    leftBranch: left child node

    rightBranch: right child node

    col: column index that yields the largest information gain

    value: the value used to split the data under that optimal column

    results: classification result (only set on leaf nodes)

    summary: sample information at the node (impurity and sample count)

    data: the data set held at the node

    class Tree:
        def __init__(self, leftBranch=None, rightBranch=None, col=-1, value=None, results=None, summary=None, data=None):
            self.leftBranch = leftBranch
            self.rightBranch = rightBranch
            self.col = col
            self.value = value
            self.results = results
            self.summary = summary
            self.data = data

        def __str__(self):
            return (u"column index: %d\n" % self.col +
                    u"split value: %s\n" % self.value +
                    u"sample summary: %s" % self.summary)
    

    Split the data set

    def splitDataSet(dataSet, value, column):
        leftList = []
        rightList = []
        # check whether value is numeric
        if isinstance(value, int) or isinstance(value, float):
            # iterate over every row
            for rowData in dataSet:
                # if the value in the given column is >= value, put the row into leftList, otherwise into rightList
                if rowData[column] >= value:
                    leftList.append(rowData)
                else:
                    rightList.append(rowData)
        # value is nominal (categorical)
        else:
            # iterate over every row
            for rowData in dataSet:
                # if the value in the given column equals value, put the row into leftList, otherwise into rightList
                if rowData[column] == value:
                    leftList.append(rowData)
                else:
                    rightList.append(rowData)
        return leftList, rightList

    Count the number of samples for each label class

    '''
    Helper for computing the Gini value. For example, if the labels in dataSet are
    ['A', 'B', 'C', 'A', 'A', 'D'], the output is {'A': 3, 'B': 1, 'C': 1, 'D': 1},
    i.e. a count of each class in dataSet.
    '''

    def calculateDiffCount(dataSet):
        results = {}
        for data in dataSet:
            # data[-1] is the last column of the row, i.e. the class label
            if data[-1] not in results:
                results.setdefault(data[-1], 1)
            else:
                results[data[-1]] += 1
        return results

    Gini index formula implementation
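
    For a data set D with K classes, where p_k is the fraction of samples belonging to class k, the quantity the function below computes is:

        Gini(D) = 1 - Σ_k p_k²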

    def gini(dataSet):
        # compute the Gini value
        # total number of rows
        length = len(dataSet)
        # per-class sample counts
        results = calculateDiffCount(dataSet)
        imp = 0.0
        for i in results:
            imp += results[i] / length * results[i] / length
        return 1 - imp
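
    A quick sanity check on a hypothetical three-row data set (not from the post), with one feature column and a label column:

        toy = [[1, 'A'], [2, 'A'], [3, 'B']]
        # p_A = 2/3, p_B = 1/3, so Gini = 1 - ((2/3)**2 + (1/3)**2) = 4/9
        print(gini(toy))  # -> 0.444...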

    Build the decision tree

    '''Algorithm steps'''
    '''Starting from the root node, recursively apply the following steps to each node to build a binary decision tree from the training data:
    1. Let the training data at the node be D. For each feature A and each of its possible values a, split D into D1 and D2
       according to whether the test A >= a (or A == a for nominal features) is "yes" or "no", and use the Gini index to compute the information gain.
    2. Among all features A and all candidate split points a, choose the feature and split point with the largest information gain as the optimal
       feature and optimal split point, create two child nodes from the current node, and distribute the training data to the two children accordingly.
    3. Recursively apply steps 1 and 2 to the two child nodes until the stopping condition is met.
    4. The result is a CART decision tree.
    '''

    evaluationFunc=gini: the Gini index is used as the impurity measure

    def buildDecisionTree(dataSet, evaluationFunc=gini):
        # Gini index of the full data set at this node
        baseGain = evaluationFunc(dataSet)
        # length of each row (i.e. total number of columns)
        columnLength = len(dataSet[0])
        # total number of rows
        rowLength = len(dataSet)
        # initialization
        bestGain = 0.0    # largest information gain
        bestValue = None  # (column index, split value) at the largest information gain
        bestSet = None    # data subsets produced by the best split
        # iterate over every column except the label column (the last one)
        for col in range(columnLength - 1):
            # values of the given column
            colSet = [example[col] for example in dataSet]
            # unique values of the given column
            uniqueColSet = set(colSet)
            # iterate over the unique values of the column
            for value in uniqueColSet:
                # split the data set
                leftDataSet, rightDataSet = splitDataSet(dataSet, value, col)
                # probability of the left subset; in Python 3 "/" is true division
                prop = len(leftDataSet) / rowLength
                # information gain
                infoGain = baseGain - prop * evaluationFunc(leftDataSet) - (1 - prop) * evaluationFunc(rightDataSet)
                # keep the column index, value and data subsets with the largest information gain
                if infoGain > bestGain:
                    bestGain = infoGain
                    bestValue = (col, value)
                    bestSet = (leftDataSet, rightDataSet)
        # node information
        # nodeDescription = {'impurity:%.3f'%baseGain,'sample:%d'%rowLength}
        nodeDescription = {'impurity': '%.3f' % baseGain, 'sample': '%d' % rowLength}
        # the labels are not all identical, so we can keep splitting;
        # the recursion needs a termination condition
        if bestGain > 0:
            # recursively build the left and right child nodes
            leftBranch = buildDecisionTree(bestSet[0], evaluationFunc)
            rightBranch = buildDecisionTree(bestSet[1], evaluationFunc)
            return Tree(leftBranch=leftBranch, rightBranch=rightBranch, col=bestValue[0],
                        value=bestValue[1], summary=nodeDescription, data=bestSet)
        else:
            # all labels are identical, stop splitting
            return Tree(results=calculateDiffCount(dataSet), summary=nodeDescription, data=dataSet)
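
    A minimal usage sketch with a hypothetical toy data set (two numeric feature columns plus a label column; the post itself builds the tree from a CSV file in the main script at the end):

        toyData = [
            [2.7, 1.0, 'no'],
            [1.3, 0.5, 'no'],
            [3.5, 2.0, 'yes'],
            [3.1, 1.8, 'yes'],
        ]
        toyTree = buildDecisionTree(toyData, evaluationFunc=gini)
        print(toyTree)                        # root node: column index, split value, node summary
        print(classify([3.6, 2.1], toyTree))  # classify (defined further below) should reach the 'yes' leaf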

    def createTree(dataSet, evaluationFunc=gini):
        # recursively build a plot-friendly decision tree (nested dicts); recursion stops when gain == 0
        # Gini index of the full data set at this node
        baseGain = evaluationFunc(dataSet)
        # length of each row (i.e. total number of columns)
        columnLength = len(dataSet[0])
        # total number of rows
        rowLength = len(dataSet)
        # initialization
        bestGain = 0.0    # largest information gain
        bestValue = None  # (column index, split value) at the largest information gain
        bestSet = None    # data subsets produced by the best split
        # iterate over every column except the label column (the last one)
        for col in range(columnLength - 1):
            # values of the given column
            colSet = [example[col] for example in dataSet]
            # unique values of the given column
            uniqueColSet = set(colSet)
            # iterate over the unique values of the column
            for value in uniqueColSet:
                # split the data set
                leftDataSet, rightDataSet = splitDataSet(dataSet, value, col)
                # probability of the left subset; in Python 3 "/" is true division
                prop = len(leftDataSet) / rowLength
                # information gain
                infoGain = baseGain - prop * evaluationFunc(leftDataSet) - (1 - prop) * evaluationFunc(rightDataSet)
                # keep the column index, value and data subsets with the largest information gain
                if infoGain > bestGain:
                    bestGain = infoGain
                    bestValue = (col, value)
                    bestSet = (leftDataSet, rightDataSet)

        impurity = u'%.3f' % baseGain
        sample = '%d' % rowLength

        if bestGain > 0:
            bestFeatLabel = u'serial:%s\nimpurity:%s\nsample:%s' % (bestValue[0], impurity, sample)
            myTree = {bestFeatLabel: {}}
            myTree[bestFeatLabel][bestValue[1]] = createTree(bestSet[0], evaluationFunc)
            myTree[bestFeatLabel]['no'] = createTree(bestSet[1], evaluationFunc)
            return myTree
        else:  # the recursion needs a return value
            bestFeatValue = u'%s\nimpurity:%s\nsample:%s' % (str(calculateDiffCount(dataSet)), impurity, sample)
            return bestFeatValue
    

    Classification test:

    '''Traverse the binary tree with the given test sample and find the matching leaf node.'''
    '''For example, with the test sample [5.9, 3, 4.2, 1.75], classification follows the tree built from the training data:
    column 2 of the test sample is 4.2 => compare it with the value (3) stored at the root node (column 2);
    since 4.2 >= 3 we walk the left subtree, otherwise the right subtree, until a leaf node gives the result.'''

    def classify(data, tree):
        # if this is a leaf node, return its information; otherwise keep traversing
        if tree.results is not None:
            return u"%s %s" % (tree.results, tree.summary)
        else:
            branch = None
            v = data[tree.col]
            # numeric value
            if isinstance(v, int) or isinstance(v, float):
                if v >= tree.value:
                    branch = tree.leftBranch
                else:
                    branch = tree.rightBranch
            else:  # nominal value
                if v == tree.value:
                    branch = tree.leftBranch
                else:
                    branch = tree.rightBranch
            return classify(data, branch)

    def loadCSV(fileName):
        def convertTypes(s):
            s = s.strip()
            try:
                return float(s) if '.' in s else int(s)
            except ValueError:
                return s

        # read the CSV as strings, drop the header row, then convert each cell to int/float where possible
        data = np.loadtxt(fileName, dtype='str', delimiter=',')
        data = data[1:, :]
        dataSet = [[convertTypes(item) for item in row] for row in data]
        return dataSet
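
    The post does not include dataSet.csv itself; judging from the test sample [5.9, 3, 4.2, 1.75] it looks like iris-style data. A hypothetical file in the layout loadCSV expects (one header row, feature columns, label column last; the column names here are made up) would be:

        sepal_length,sepal_width,petal_length,petal_width,label
        5.1,3.5,1.4,0.2,setosa
        7.0,3.2,4.7,1.4,versicolor
        6.3,3.3,6.0,2.5,virginica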
    

    Majority voter

    The result is the value that appears most often in the list

    def majorityCnt(classList):
        import operator
        classCounts = {}
        for value in classList:
            if value not in classCounts.keys():
                classCounts[value] = 0
            classCounts[value] += 1
        sortedClassCount = sorted(classCounts.items(), key=operator.itemgetter(1), reverse=True)
        return sortedClassCount[0][0]
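
    For example (a hypothetical call, just to illustrate the helper), majorityCnt(['A', 'A', 'B']) returns 'A', the most frequent label.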

    Pruning algorithm (pre-order traversal: root => left subtree => right subtree)

    '''Algorithm steps

    1. Starting from the root of the binary tree, call the pruning routine recursively until both children of a node are leaf nodes.
    2. For such a parent node (both children are leaves), compute its information gain infoGain.
    3. If infoGain < miniGain, replace the parent node with the leaf that holds the majority class of the merged samples.
    4. Repeat steps 1-3 until the whole tree has been traversed.
    '''
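
    A small worked example of the pruning test (hypothetical leaves, not from the post): suppose the left leaf holds three 'A' samples and the right leaf one 'B' sample. Then:

        gini(merged) = 1 - ((3/4)**2 + (1/4)**2) = 0.375
        gini(left)   = 0,  gini(right) = 0
        infoGain     = 0.375 - (3/4)*0 - (1/4)*0 = 0.375

    With miniGain = 0.4 (the threshold used in the test script at the end), 0.375 < 0.4, so the two leaves would be merged back into their parent.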

    def prune(tree, miniGain, evaluationFunc=gini):
        print(u"current node:")
        print(str(tree))
        # if the left child of the current node is not a leaf, recurse into the left subtree
        if tree.leftBranch.results is None:
            print(u"left child node:")
            print(str(tree.leftBranch))
            prune(tree.leftBranch, miniGain, evaluationFunc)
        # if the right child of the current node is not a leaf, recurse into the right subtree
        if tree.rightBranch.results is None:
            print(u"right child node:")
            print(str(tree.rightBranch))
            prune(tree.rightBranch, miniGain, evaluationFunc)
        # both children are leaf nodes
        if tree.leftBranch.results is not None and tree.rightBranch.results is not None:
            # number of samples in the left leaf
            leftLen = len(tree.leftBranch.data)
            # number of samples in the right leaf
            rightLen = len(tree.rightBranch.data)
            # probability of the left leaf
            leftProp = leftLen / (leftLen + rightLen)
            # information gain of this node (whose children are leaves)
            infoGain = (evaluationFunc(tree.leftBranch.data + tree.rightBranch.data) -
                        leftProp * evaluationFunc(tree.leftBranch.data) -
                        (1 - leftProp) * evaluationFunc(tree.rightBranch.data))
            # if the gain is below the given threshold, the leaves differ little from their parent and can be pruned
            if infoGain < miniGain:
                # merge the data of the two leaves
                dataSet = tree.leftBranch.data + tree.rightBranch.data
                # label column
                classLabels = [example[-1] for example in dataSet]
                # label with the most samples
                keyLabel = majorityCnt(classLabels)
                # decide which leaf holds that label
                if keyLabel in tree.leftBranch.results:
                    # the left leaf replaces the parent node
                    tree.data = tree.leftBranch.data
                    tree.results = tree.leftBranch.results
                    tree.summary = tree.leftBranch.summary
                else:
                    # the right leaf replaces the parent node
                    tree.data = tree.rightBranch.data
                    tree.results = tree.rightBranch.results
                    tree.summary = tree.rightBranch.summary
                tree.leftBranch = None
                tree.rightBranch = None
    ########################################################################################################

    ################################################################################################
    '''
    Created on Oct 14, 2010

    @author: Peter Harrington
    '''
    import matplotlib.pyplot as plt

    decisionNode = dict(boxstyle="sawtooth", fc="0.8")
    leafNode = dict(boxstyle="circle", fc="0.7")
    arrow_args = dict(arrowstyle="<-")

    Get the number of leaf nodes of the tree

    def getNumLeafs(myTree):
        numLeafs = 0
        # convert the dict keys to a list
        firstSides = list(myTree.keys())
        firstStr = firstSides[0]
        secondDict = myTree[firstStr]
        for key in secondDict.keys():
            # test whether the node is a dict: a dict means it has children, otherwise (str) it is a leaf node
            if type(secondDict[key]).__name__ == 'dict':
                numLeafs += getNumLeafs(secondDict[key])
            else:
                numLeafs += 1
        return numLeafs

    Get the depth of the tree

    def getTreeDepth(myTree):
        maxDepth = 0
        # convert the dict keys to a list
        firstSides = list(myTree.keys())
        firstStr = firstSides[0]
        secondDict = myTree[firstStr]
        for key in secondDict.keys():
            # test whether the node is a dict: a dict means an internal node, otherwise it is a leaf node
            if type(secondDict[key]).__name__ == 'dict':
                thisDepth = 1 + getTreeDepth(secondDict[key])
            else:
                thisDepth = 1
            if thisDepth > maxDepth:
                maxDepth = thisDepth
        return maxDepth

    def plotNode(nodeTxt, centerPt, parentPt, nodeType):
        createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                                xytext=centerPt, textcoords='axes fraction',
                                va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

    def plotMidText(cntrPt, parentPt, txtString):
        xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
        yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
        createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

    def plotTree(myTree, parentPt, nodeTxt):  # the first key tells you what feature was split on
        numLeafs = getNumLeafs(myTree)  # this determines the x width of this tree
        depth = getTreeDepth(myTree)
        firstSides = list(myTree.keys())
        firstStr = firstSides[0]  # the text label for this node
        cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
        plotMidText(cntrPt, parentPt, nodeTxt)
        plotNode(firstStr, cntrPt, parentPt, decisionNode)
        secondDict = myTree[firstStr]
        plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
        for key in secondDict.keys():
            if type(secondDict[key]).__name__ == 'dict':  # a dict means an internal node, otherwise it is a leaf
                plotTree(secondDict[key], cntrPt, str(key))  # recursion
            else:  # it's a leaf node: draw the leaf node
                plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
                plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
                plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
        plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

    If you do get a dictionary you know it's a tree, and the first element will be another dict.

    Plot the decision tree: example 1

    def createPlot(inTree):
        fig = plt.figure(1, facecolor='white')
        fig.clf()
        axprops = dict(xticks=[], yticks=[])
        createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticks
        # createPlot.ax1 = plt.subplot(111, frameon=False)  # ticks for demo purposes
        # width / height spacing; the "-3" / "-2" offsets are an alternative tweak for denser trees
        # plotTree.totalW = float(getNumLeafs(inTree)) - 3
        # plotTree.totalD = float(getTreeDepth(inTree)) - 2
        plotTree.totalW = float(getNumLeafs(inTree))
        plotTree.totalD = float(getTreeDepth(inTree))
        plotTree.xOff = -0.5 / plotTree.totalW
        plotTree.yOff = 1.0
        plotTree(inTree, (0.95, 1.0), '')
        plt.show()
    

    Plot the decision tree: example 2

    def createPlot1(inTree):
        # fig = plt.figure(1, facecolor='white')
        fig = plt.figure(dpi=255)
        fig.clf()
        axprops = dict(xticks=[], yticks=[])
        createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticks
        # createPlot.ax1 = plt.subplot(111, frameon=False)  # ticks for demo purposes
        # width / height spacing (tweaked with fixed offsets for this data set)
        plotTree.totalW = float(getNumLeafs(inTree)) - 4.5
        plotTree.totalD = float(getTreeDepth(inTree)) - 3
        plotTree.xOff = -0.5 / plotTree.totalW
        plotTree.yOff = 1.0
        plotTree(inTree, (1.0, 1.0), '')
        plt.show()
    

    Plot a root node and a leaf node (root node shape: rectangle, leaf node: ellipse)

    def createPlot():
        # simple demo variant; note that it reuses the name createPlot,
        # so keep only one of the two createPlot definitions active at a time
        fig = plt.figure(1, facecolor='white')
        fig.clf()
        createPlot.ax1 = plt.subplot(111, frameon=False)  # ticks shown for demo purposes
        plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
        plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
        plt.show()

    def retrieveTree(i):
        # two stored sample trees for testing the plotting functions
        listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                       {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                       ]
        return listOfTrees[i]

    # Example calls (run one plotting variant at a time):
    # thisTree = retrieveTree(0)
    # createPlot(thisTree)   # plot a stored sample tree
    # createPlot()           # demo: draw one decision node and one leaf node

    myTree = retrieveTree(0)
    numLeafs = getNumLeafs(myTree)
    treeDepth = getTreeDepth(myTree)
    print(u"number of leaf nodes: %d" % numLeafs)
    print(u"tree depth: %d" % treeDepth)

    ##########################################################################################################
    ###################################################################################################

    # -*- coding: utf-8 -*-

    """
    Created on Wed Aug 15 14:16:59 2018

    @author: weixw
    """
    import Demo_1.myCart as mc
    from Demo_1.myCart import gini

    if __name__ == '__main__':
        import treePlotter as tp
        dataSet = mc.loadCSV("F:C盘移过来的文件dataSet.csv")
        myTree = mc.createTree(dataSet, evaluationFunc=gini)
        print(u"myTree:%s" % myTree)
        # plot the decision tree
        print(u"plotting the decision tree:")
        tp.createPlot1(myTree)
        decisionTree = mc.buildDecisionTree(dataSet, evaluationFunc=gini)
        testData = [5.9, 3, 4.2, 1.75]
        r = mc.classify(testData, decisionTree)
        print(u"classification result on the test sample:")
        print(r)
        print()
        mc.prune(decisionTree, 0.4)
        r1 = mc.classify(testData, decisionTree)
        print(u"classification result after pruning:")
        print(r1)
    ################################################################################################################

  • Original article: https://www.cnblogs.com/131415-520/p/11802578.html