zoukankan      html  css  js  c++  java
  • 决策树——ID3

    参考网址:https://www.cnblogs.com/further-further-further/p/9429257.html

    ID3算法
    ###########################################################################
    最优决策树生成

    # -*- coding: utf-8 -*-

    """
    Created on Thu Aug 2 17:09:34 2018
    决策树ID3的实现
    @author: weixw
    """
    from math import log
    import operator

    原始数据

    def createDataSet():
        """Return the toy training set and the names of its feature columns.

        Each row holds four binary feature values followed by the class label
        ('yes' or 'no') in the last position.  The sample deliberately contains
        repeated rows and rows whose features are identical but whose class
        labels differ.

        Returns:
            A ``(dataSet, labels)`` tuple: ``dataSet`` is a list of rows and
            ``labels`` is the list of feature names, one per feature column.
        """
        dataSet = [
            [1, 1, 1, 1, 'yes'],
            [1, 1, 0, 0, 'yes'],
            [1, 0, 1, 1, 'no'],
            [0, 1, 0, 1, 'yes'],
            [0, 1, 1, 0, 'yes'],
            [1, 1, 1, 1, 'yes'],
            [1, 1, 0, 0, 'no'],
            [1, 0, 1, 1, 'no'],
            [0, 1, 0, 1, 'no'],
            [0, 1, 1, 0, 'no'],
        ]
        labels = ['no surfacing', 'flippers', 'people', 'day']
        return dataSet, labels

    多数表决器

    列中相同值数量最多为结果

    def majorityCnt(classList):
        """Return the value that occurs most often in ``classList``.

        Used as a majority vote when a leaf still holds mixed class labels.

        Args:
            classList: a non-empty sequence of hashable class labels.

        Returns:
            The most frequent label.  When several labels are tied, the one
            that appears first in ``classList`` wins, because the sort below is
            stable and the dict preserves insertion order.

        Raises:
            IndexError: if ``classList`` is empty.
        """
        classCounts = {}
        for value in classList:
            classCounts[value] = classCounts.get(value, 0) + 1
        # dict.items() replaces the Python 2-only dict.iteritems().
        sortedClassCount = sorted(classCounts.items(),
                                  key=operator.itemgetter(1),
                                  reverse=True)
        return sortedClassCount[0][0]

    划分数据集

    dataSet:原始数据集

    axis:进行分割的指定列索引

    value:指定列中的值

    def splitDataSet(dataSet, axis, value):
        """Select the rows whose column ``axis`` equals ``value``, minus that column.

        Args:
            dataSet: list of rows (each row a list).
            axis: index of the column to test and then remove.
            value: the value the column must equal for a row to be kept.

        Returns:
            A new list of new rows; the input rows are not modified.
        """
        # Joining the slices before and after `axis` drops exactly that column.
        return [
            featDataVal[:axis] + featDataVal[axis + 1:]
            for featDataVal in dataSet
            if featDataVal[axis] == value
        ]

    计算香农熵

    def calcShannonEnt(dataSet):
        """Compute the Shannon entropy (base 2) of the class labels in ``dataSet``.

        Args:
            dataSet: list of rows; the last element of each row is its class
                label.

        Returns:
            The entropy as a float.  An empty data set yields 0.0 because no
            label contributes to the sum.
        """
        numEntries = len(dataSet)
        # Count how many rows carry each class label.
        labelCounts = {}
        for featDataVal in dataSet:
            currentLabel = featDataVal[-1]
            labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
        # Accumulate -p * log2(p) over the observed labels.
        shannonEnt = 0.0
        for count in labelCounts.values():
            prop = count / float(numEntries)
            shannonEnt -= prop * log(prop, 2)
        return shannonEnt

    选出最优特征列索引

    def chooseBestFeatureToSplit(dataSet):
        """Pick the feature column whose split yields the largest information gain.

        Args:
            dataSet: non-empty list of rows; every column except the last is a
                feature, and the last column is the class label.

        Returns:
            The index of the best feature column, or -1 when no feature gives
            a strictly positive information gain (callers must handle -1).
        """
        # The last column is the class label, not a feature.
        numFeatures = len(dataSet[0]) - 1
        baseEntropy = calcShannonEnt(dataSet)
        totalRows = float(len(dataSet))
        bestInfoGain = 0.0
        bestFeatureIndex = -1
        for i in range(numFeatures):
            uniqueVals = set(example[i] for example in dataSet)
            # Conditional entropy: each subset's entropy weighted by its share
            # of the rows.
            newEntropy = 0.0
            for value in uniqueVals:
                subDataSet = splitDataSet(dataSet, i, value)
                prob = len(subDataSet) / totalRows
                newEntropy += prob * calcShannonEnt(subDataSet)
            infoGain = baseEntropy - newEntropy
            # Strict comparison: on ties the earliest column is kept.
            if infoGain > bestInfoGain:
                bestInfoGain = infoGain
                bestFeatureIndex = i
        return bestFeatureIndex

    决策树创建

    def createTree(dataSet, labels):
        """Recursively build an ID3 decision tree as nested dicts.

        Args:
            dataSet: non-empty list of rows; the last element of each row is
                its class label, the preceding elements are feature values.
            labels: feature names, aligned with the feature columns of
                ``dataSet``.  The list is not modified.

        Returns:
            Either a class label (leaf) or a dict of the form
            ``{feature_name: {feature_value: subtree_or_label, ...}}``.
        """
        classList = [example[-1] for example in dataSet]
        # Stop: every row has the same class.
        if classList.count(classList[0]) == len(classList):
            return classList[0]
        # Stop: only the class column is left, so take a majority vote.
        if len(dataSet[0]) == 1:
            return majorityCnt(classList)
        bestFeatureIndex = chooseBestFeatureToSplit(dataSet)
        # chooseBestFeatureToSplit returns -1 when no feature has positive
        # information gain.  Using -1 as an index would split on the class
        # column itself, so fall back to a majority vote instead.
        if bestFeatureIndex < 0:
            return majorityCnt(classList)
        bestFeatureLabel = labels[bestFeatureIndex]
        myTree = {bestFeatureLabel: {}}
        # Build the reduced label list without mutating the caller's list.
        subLabels = labels[:bestFeatureIndex] + labels[bestFeatureIndex + 1:]
        uniquesVals = set(example[bestFeatureIndex] for example in dataSet)
        for value in uniquesVals:
            myTree[bestFeatureLabel][value] = createTree(
                splitDataSet(dataSet, bestFeatureIndex, value), subLabels)
        return myTree

    获取分类结果

    inputTree:决策树字典

    featLabels:标签列表

    testVec:测试向量 例如:简单实例下某一路径 [1,1] => yes(树干值组合,从根结点到叶子节点)

    def classify(inputTree, featLabels, testVec):
        """Classify ``testVec`` by walking the decision tree from the root.

        Args:
            inputTree: decision tree as nested dicts,
                ``{feature_name: {feature_value: subtree_or_label}}``.
            featLabels: list of feature names; the position of a name is the
                position of that feature's value in ``testVec``.
            testVec: feature values of the sample to classify.

        Returns:
            The class label stored at the leaf that is reached.

        Raises:
            ValueError: if the node's feature name is not in ``featLabels``.
            KeyError: if the sample's value for a feature has no branch in
                the tree.
        """
        # A tree node is a single-key dict whose key is the feature name.
        firstStr = list(inputTree.keys())[0]
        secondDict = inputTree[firstStr]
        featIndex = featLabels.index(firstStr)
        valueOfFeat = secondDict[testVec[featIndex]]
        # A dict is an inner node: keep descending.  Anything else is a leaf.
        if isinstance(valueOfFeat, dict):
            return classify(valueOfFeat, featLabels, testVec)
        return valueOfFeat

    将决策树分类器存储在磁盘中,filename一般保存为txt格式

    def storeTree(inputTree, filename):
        """Serialize ``inputTree`` to ``filename`` with pickle.

        Args:
            inputTree: the decision tree (any picklable object).
            filename: path of the file to write; it is opened in binary mode
                and any existing content is overwritten.
        """
        import pickle
        # The context manager closes the file even if pickling fails.
        with open(filename, 'wb+') as fw:
            pickle.dump(inputTree, fw)

    将磁盘中的对象加载出来,这里的filename就是上面函数中的txt文件

    def grabTree(filename):
        """Load and return the object that was pickled into ``filename``.

        Args:
            filename: path of a file containing a pickled object.

        Returns:
            The unpickled object.
        """
        import pickle
        # NOTE(security): unpickling can execute arbitrary code, so only load
        # files that come from a trusted source.
        # The context manager closes the handle the original left open.
        with open(filename, 'rb') as fr:
            return pickle.load(fr)
    ################################################################################

    决策树绘制

    '''
    Created on Oct 14, 2010

    @author: Peter Harrington
    '''
    import matplotlib.pyplot as plt

    # Box style for decision (inner) nodes, passed to annotate() as bbox.
    decisionNode = {"boxstyle": "sawtooth", "fc": "0.8"}
    # Box style for leaf nodes, passed to annotate() as bbox.
    leafNode = {"boxstyle": "round4", "fc": "0.8"}
    # Arrow style for the connector lines, passed to annotate() as arrowprops.
    arrow_args = {"arrowstyle": "<-"}

    获取树的叶子节点

    def getNumLeafs(myTree):
        """Count the leaf nodes of a decision tree.

        Args:
            myTree: decision tree as nested dicts,
                ``{feature_name: {feature_value: subtree_or_label}}``.

        Returns:
            The number of leaves (non-dict values) in the tree.
        """
        numLeafs = 0
        # A tree node is a single-key dict; its value maps branches to children.
        firstStr = list(myTree.keys())[0]
        secondDict = myTree[firstStr]
        for child in secondDict.values():
            # isinstance replaces the broken `type(...).name` check, which
            # raised AttributeError.  Dict children are subtrees.
            if isinstance(child, dict):
                numLeafs += getNumLeafs(child)
            else:
                numLeafs += 1
        return numLeafs

    获取树的层数

    def getTreeDepth(myTree):
        """Return the number of decision levels on the longest root-to-leaf path.

        Args:
            myTree: decision tree as nested dicts,
                ``{feature_name: {feature_value: subtree_or_label}}``.

        Returns:
            The depth as an int; a node whose children are all leaves has
            depth 1.
        """
        maxDepth = 0
        firstStr = list(myTree.keys())[0]
        secondDict = myTree[firstStr]
        for child in secondDict.values():
            # isinstance replaces the broken `type(...).name` check, which
            # raised AttributeError.  Dict children are subtrees.
            if isinstance(child, dict):
                thisDepth = 1 + getTreeDepth(child)
            else:
                thisDepth = 1
            if thisDepth > maxDepth:
                maxDepth = thisDepth
        return maxDepth

    def plotNode(nodeTxt, centerPt, parentPt, nodeType):
        """Draw one boxed node and the connector between it and its parent.

        Draws on ``createPlot.ax1``, so ``createPlot`` must have set up the
        axes first.

        Args:
            nodeTxt: text shown inside the box.
            centerPt: (x, y) position of the box, in axes-fraction coordinates.
            parentPt: (x, y) position of the parent node, in axes-fraction
                coordinates.
            nodeType: box style dict, e.g. ``decisionNode`` or ``leafNode``.
        """
        createPlot.ax1.annotate(nodeTxt,
                                xy=parentPt, xycoords='axes fraction',
                                xytext=centerPt, textcoords='axes fraction',
                                va="center", ha="center",
                                bbox=nodeType, arrowprops=arrow_args)

    def plotMidText(cntrPt, parentPt, txtString):
        """Write ``txtString`` at the midpoint between a node and its parent.

        Draws on ``createPlot.ax1``, so ``createPlot`` must have set up the
        axes first.

        Args:
            cntrPt: (x, y) position of the child node.
            parentPt: (x, y) position of the parent node.
            txtString: text to place on the connecting edge.
        """
        xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
        yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
        createPlot.ax1.text(xMid, yMid, txtString,
                            va="center", ha="center", rotation=30)

    def plotTree(myTree, parentPt, nodeTxt):
        """Recursively draw ``myTree`` below ``parentPt``.

        Layout state lives in function attributes that ``createPlot``
        initializes before the first call: ``plotTree.totalW`` (leaf count),
        ``plotTree.totalD`` (depth), and the running cursor ``plotTree.xOff``
        / ``plotTree.yOff``.

        Args:
            myTree: decision tree as nested dicts,
                ``{feature_name: {feature_value: subtree_or_label}}``.
            parentPt: (x, y) position of the parent node.
            nodeTxt: text for the edge from the parent to this subtree.
        """
        # The leaf count of this subtree determines its horizontal extent.
        numLeafs = getNumLeafs(myTree)
        # The single key of the node dict is the feature that was split on.
        firstStr = list(myTree.keys())[0]
        # Centre this node above the leaves it spans.
        cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
                  plotTree.yOff)
        plotMidText(cntrPt, parentPt, nodeTxt)
        plotNode(firstStr, cntrPt, parentPt, decisionNode)
        secondDict = myTree[firstStr]
        # Step one level down for the children.
        plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
        for key in secondDict.keys():
            child = secondDict[key]
            # isinstance replaces the broken `type(...).name` check, which
            # raised AttributeError.  Dict children are subtrees.
            if isinstance(child, dict):
                plotTree(child, cntrPt, str(key))
            else:
                # Leaf: move the x cursor one slot to the right and draw it.
                plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
                plotNode(child, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
                plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
        # Restore the y cursor for the caller's remaining siblings.
        plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

    if you do get a dictionary you know it's a tree, and the first element will be another dict

    绘制决策树

    def createPlot(inTree):
        """Draw the decision tree ``inTree`` in a matplotlib figure and show it.

        Creates the axes as ``createPlot.ax1`` and initializes the layout
        state that ``plotTree`` reads (``totalW``, ``totalD``, ``xOff``,
        ``yOff``) before drawing.

        Args:
            inTree: decision tree as nested dicts,
                ``{feature_name: {feature_value: subtree_or_label}}``.
        """
        fig = plt.figure(1, facecolor='white')
        fig.clf()
        # Hide the ticks; the tree is laid out in axes-fraction coordinates.
        axprops = dict(xticks=[], yticks=[])
        createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
        plotTree.totalW = float(getNumLeafs(inTree))
        plotTree.totalD = float(getTreeDepth(inTree))
        # Start half a leaf slot left of the axes so that the first leaf,
        # placed one full slot further right, is centred in its slot.
        plotTree.xOff = -0.5 / plotTree.totalW
        plotTree.yOff = 1.0
        plotTree(inTree, (0.5, 1.0), '')
        plt.show()

    绘制树的根节点和叶子节点(根节点形状:长方形,叶子节点:椭圆形)

    # def createPlot():
    #     fig = plt.figure(1, facecolor='white')
    #     fig.clf()
    #     createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo purposes
    #     plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
    #     plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
    #     plt.show()

    def retrieveTree(i):
        """Return one of two hard-coded sample decision trees.

        Args:
            i: index of the sample tree, 0 or 1.

        Returns:
            The selected tree as nested dicts.

        Raises:
            IndexError: if ``i`` is outside the list of sample trees.
        """
        listOfTrees = [
            {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
            {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}},
        ]
        return listOfTrees[i]

    # thisTree = retrieveTree(0)
    # createPlot(thisTree)
    # createPlot()
    # myTree = retrieveTree(0)
    # numLeafs =getNumLeafs(myTree)
    # treeDepth =getTreeDepth(myTree)
    # print(u"叶子节点数目:%d"% numLeafs)
    # print(u"树深度:%d"%treeDepth)

    #####################################################################################

    测试代码

    # -*- coding: utf-8 -*-

    """
    Created on Fri Aug 3 19:52:10 2018

    @author: weixw
    """
    import Demo_1.myTrees as mt
    import Demo_1.treePlotter as tp

    测试

    # Build the sample data set and the list of feature names.
    dataSet, labels = mt.createDataSet()

    # createTree may modify the label list it is given, so pass it a copy and
    # keep `labels` intact for the classify() calls below.
    labels1 = labels.copy()
    myTree = mt.createTree(dataSet, labels1)

    # Save the tree to disk, then load it back.
    mt.storeTree(myTree, 'myTree.txt')
    myTree = mt.grabTree('myTree.txt')
    print(u"决策树结构:%s" % myTree)

    # Draw the decision tree and report its size.
    print(u"绘制决策树:")
    tp.createPlot(myTree)
    numLeafs = tp.getNumLeafs(myTree)
    treeDepth = tp.getTreeDepth(myTree)
    print(u"叶子节点数目:%d" % numLeafs)
    print(u"树深度:%d" % treeDepth)

    # Classify two samples.  Each vector holds one value per feature name in
    # `labels`; the printed prefix now shows the vector actually classified.
    labelResult = mt.classify(myTree, labels, [1, 1, 1, 0])
    print(u"[1,1,1,0] 测试结果为:%s" % labelResult)
    labelResult = mt.classify(myTree, labels, [1, 0, 0, 0])
    print(u"[1,0,0,0] 测试结果为:%s" % labelResult)
    ############################################################################################

  • 相关阅读:
    ansible常用的一些模块
    使用jmx监控tomcat
    snmp的监控
    Selenium 入门到精通系列:六
    Selenium 入门到精通系列:五
    Selenium 入门到精通系列:四
    Selenium 入门到精通系列:三
    Selenium 入门到精通系列:二
    Selenium 入门到精通系列:一
    Python 发邮件例子
  • 原文地址:https://www.cnblogs.com/131415-520/p/11789727.html
Copyright © 2011-2022 走看看