zoukankan      html  css  js  c++  java
  • 决策树

     决策树的一个重要任务,就是为了理解数据中蕴含的知识信息,因此决策树可以使用不熟悉的数据集合,并从中提取出一系列规则,这些机器根据数据集创建规则的过程,就是机器学习的过程。

    一、确定划分数据集的决定性特征

    信息增益:划分数据集前后信息发生的变化

    信息:l(xi)=-log2p(xi),p(xi)是选择该分类的概率

    熵(信息的期望值,表示序集无需程度的度量):H=-Σp(xi)log2p(xi)

    二、划分数据集:

    ①计算原始数据集的熵

    ②将需要划分特征的数据标签抽取出来,计算剩下的数据集的熵

    ③将两个熵做差,算出计算得信息增益

    ④对每一个数据标签重复上述操作,最终算出信息增益最大得数据标签,这个数据标签就是决定性特征。

    ⑤对剩下来得数据标签再进行挑选分类,建立决策树。

    三、存在问题

    决策树可能会产生过多的数据集划分,从而产生过度匹配数据集的问题。我们可以通过裁剪决策树,合并相邻的无法产生大量信息增益的叶节点,消除过度匹配问题

    决策树和KNN算法都是谈论具有明确分类的分类算法,朴素贝叶斯分类是一定概率的分类算法。

    import operator
    import matplotlib.pyplot as plt
    from math import log
    #计算给定数据集的香农熵
    def calcShannonEnt(dataset):
    numEntries=len(dataset)
    labelCounts={}
    for featVec in dataset:
    currentLabel=featVec[-1]
    if currentLabel not in labelCounts.keys():
    labelCounts[currentLabel]=0
    labelCounts[currentLabel]+=1
    shannonEnt=0.0
    for key in labelCounts:
    prob=float(labelCounts[key])/numEntries
    shannonEnt-=prob*log(prob,2)
    return shannonEnt
    #mydat=[[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']]
    #print(calcShannonEnt(mydat))
    #按照给定特征划分数据集
    def splitDataSet(dataset,axis,value):
    retDataSet=[]
    for featVec in dataset:
    if featVec[axis] == value:
    reducedFeatVec=featVec[:axis]
    reducedFeatVec.extend(featVec[axis+1:])
    retDataSet.append(reducedFeatVec)
    return retDataSet
    def chooseBestFeatureTosplit(dataset):
    numFeatures=len(dataset[0])-1#获取特征数目-1
    baseEntropy=calcShannonEnt(dataset)#计算原始香农熵
    bestInfoGain=0.0;bestFeature=-1
    for i in range(numFeatures):
    featList=[example[i] for example in dataset]#建立dataset中的每一列标签的值
    uniqueVals=set(featList)#唯一化
    newEntropy=0.0
    for value in uniqueVals:
    subDataSet=splitDataSet(dataset,i,value)
    prob=len(subDataSet)/float(len(dataset))
    newEntropy+=prob*calcShannonEnt(subDataSet)#计算抽取了value后的香农熵之和
    infoGain=baseEntropy-newEntropy#计算信息增益
    if(infoGain>bestInfoGain):#找到最好的划分
    bestInfoGain=infoGain
    bestFeature=i
    return bestFeature
    #mydat=[[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']]
    #print(chooseBestFeatureTosplit(mydat)) OUTPUT:0表明第0个特征用于划分数据集最好的特征
    #返回出现最多次的分类名称
    def majorityCnt(classList):
    classCount={}
    for vote in classList:
    if vote not in classCount.keys():classCount[vote]=0
    classCount[vote]+=1
    sortedClassCount=sorted(classCount.iteritems(),key=operator.itemgetter(1),reversed=True)
    return sortedClassCount[0][0]
    #递归构建决策树
    def createTree(dataset,labels):
    classList=[example[-1] for example in dataset]#最后一列
    if classList.count(classList[0])==len(classList):#count() 方法用于统计某个元素在列表中出现的次数,如果所有的类别都是相同的
    return classList[0]
    if len(dataset[0])==1:#只剩下一个
    return majorityCnt(classList)
    bestFeat=chooseBestFeatureTosplit(dataset)
    bestFeatLabel=labels[bestFeat]
    myTree={bestFeatLabel:{}}
    del(labels[bestFeat])
    featValues=[example[bestFeat] for example in dataset]
    uniqueVals=set(featValues)
    for value in uniqueVals:
    subLabels=labels[:]
    myTree[bestFeatLabel][value]=createTree(splitDataSet(dataset,bestFeat,value),subLabels)
    return myTree
    #绘制图形
    decisionNode=dict(boxstyle="sawtooth",fc="0.8")
    leafNode = dict(boxstyle="round4", fc="0.8")
    arrow_args=dict(arrowstyle="<-")

    def plotNode(nodeText,centerPt,parentPt,nodeType):
    createPlot.ax1.annotate(nodeText,xy=parentPt,xycoords='axes fraction',xytext=centerPt,textcoords='axes fraction',va="center",
    ha="center",bbox=nodeType,arrowprops=arrow_args)
    def getNumLeafs(myTree):
    numLeafs=0
    firstStr=list(myTree.keys())[0]
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
    if type(secondDict[key]).__name__=='dict':
    numLeafs+=getNumLeafs(secondDict[key])
    else:
    numLeafs+=1
    return numLeafs
    def getTreeDepth(myTree):
    maxDepth=0
    firstStr=list(myTree.keys())[0]
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
    if type(secondDict[key]).__name__=='dict':
    thisDepth=1+getTreeDepth(secondDict[key])
    else: thisDepth=1
    if thisDepth>maxDepth:maxDepth=thisDepth
    return maxDepth
    def plotMidText(cntrPt,parentPt,txtString):
    xMid=(parentPt[0]-cntrPt[0])/2.0+cntrPt[0]
    yMid=(parentPt[1]-cntrPt[1])/2.0+cntrPt[1]
    createPlot.ax1.text(xMid,yMid,txtString)
    def plotTree(myTree,parentPt,nodeTxt):
    numLeafs=getNumLeafs(myTree)
    depth=getTreeDepth(myTree)
    firstStr=list(myTree.keys())[0]
    cntrPt=(plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
    plotMidText(cntrPt,parentPt,nodeTxt)
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    secondDict=myTree[firstStr]
    plotTree.yOff=plotTree.yOff-1.0/plotTree.totalD
    for key in secondDict.keys():
    if type(secondDict[key]).__name__=='dict':
    plotTree(secondDict[key],cntrPt,str(key))
    else:
    plotTree.xOff=plotTree.xOff+1.0/plotTree.totalW
    plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
    plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
    plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD
    def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD=float(getTreeDepth(inTree))
    plotTree.xOff=-0.5/plotTree.totalW;plotTree.yOff=1.0;
    plotTree(inTree,(0.5,1.0),'')
    plt.show()

    dataset=[[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']]
    labels=['no surfacing','flippers']#前两列属性的名称,dataset第三列表示的是目标变量
    myTree=createTree(dataset,labels)
    print(myTree)
    createPlot(myTree)
    Input:传入的数据分为n+1列,前n列对应n个属性(labels),最后一列是目标变量
    Output:决策树

    ------------------------

    虽然内容还没有跑出来,问题出现在import导入的matplotlib这个包,问题分析是之前下载了annacode,和自带的python冲突了,我把annacode的环境加进来,好像也不行,代码界面报错消失了,但是一运行就会报错,有博主说是因为创建的不是python package下的python文件,测试后发现仍然不行,最终!!!!!成了!!!!

    解决anaconda与pycharm冲突详情见https://www.cnblogs.com/code-fun/p/12488711.html

    最终!

  • 相关阅读:
    两个栈实现一个队列
    DacningLinks实现
    boost::implicit_cast
    hibernate查询之Criteria实现分页方法(GROOVY语法)
    VS2015 android 设计器不能可视化问题解决。
    当Eclipse爱上SVN
    你不知道的getComputedStyle
    推荐的软件
    React之表单
    理解javascript中的Function.prototype.bind
  • 原文地址:https://www.cnblogs.com/code-fun/p/12488735.html
Copyright © 2011-2022 走看看