zoukankan      html  css  js  c++  java
  • 决策树

    决策树对实例进行分类的树形结构,由节点和有向边组成。其实很像平时画的流程图。

    学习决策树之前要搞懂几个概念:

    熵:表示随机变量不确定性的度量,定义:H(p)=-

    信息增益:集合D的经验熵与特征A条件下D的经验条件熵H(D/A)之差(公式省略,自行查找)

    信息增益比:信息增益g(D,A)与训练数据集D关于特征A的值得熵HA(D)之比(公式省略)

    基尼系数:(公式省略)

    以上几个公式要牢记并学会推到。

    具体计算过程:

    ID3算法:寻找信息增益最大的特征

    C4.5 寻找信息增益比最大的特征

    另外有CART树算法,使用基尼系数来确定特征选择部分。

    树的剪枝分为前剪枝和后剪枝。目的为防止过拟合。

    前剪枝即为在树的构建过程中,若新增加的分支未使得准确率增加,则不进行该分支操作。

    后剪枝即构建完决策树之后,从最底部的分支开始,若去掉该分支,分类准确性增加,则去掉该分支,否则保留。

    def chooseBestFeatureToSplit(dataSet):#关键部分代码为寻找最优特征,寻找使得信息增益最大的特征
        numFeatures = len(dataSet[0]) - 1  # the last column is used for the labels
        baseEntropy = calcShannonEnt(dataSet)
        bestInfoGain = 0.0;
        bestFeature = -1
        for i in range(numFeatures):  # iterate over all the features
            featList = [example[i] for example in dataSet]  # create a list of all the examples of this feature
            uniqueVals = set(featList)  # get a set of unique values
            newEntropy = 0.0
            for value in uniqueVals:
                subDataSet = splitDataSet(dataSet, i, value)
                prob = len(subDataSet) / float(len(dataSet))
                newEntropy += prob * calcShannonEnt(subDataSet)
            infoGain = baseEntropy - newEntropy  # calculate the info gain; ie reduction in entropy
            if (infoGain > bestInfoGain):  # compare this to the best gain so far
                bestInfoGain = infoGain  # if better than current best, set to best
                bestFeature = i
        return bestFeature  # returns an integer
    

     

    def createTree(dataSet, labels): #构建决策树,寻找最优特征,做为根节点,之后根据特征分为几类,进行递归,最终形成完整决策树
        classList = [example[-1] for example in dataSet]
        if classList.count(classList[0]) == len(classList):
            return classList[0]  # stop splitting when all of the classes are equal
        if len(dataSet[0]) == 1:  # stop splitting when there are no more features in dataSet
            return majorityCnt(classList)
        bestFeat = chooseBestFeatureToSplit(dataSet)
        bestFeatLabel = labels[bestFeat]
        myTree = {bestFeatLabel: {}}
        del (labels[bestFeat])
        featValues = [example[bestFeat] for example in dataSet]
        uniqueVals = set(featValues)
        for value in uniqueVals:
            subLabels = labels[:]  # copy all of labels, so trees don't mess up existing labels
            myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
        return myTree
    

      

     对应SKlearn中的API接口:

    DecisionTreeClassifier 分类,注意此方法只有前剪枝选项。

    DecisionTreeRegressor   回归。

  • 相关阅读:
    牛客网编程练习(2018校招真题编程题汇总)------字符串价值
    牛客网编程练习(2018校招真题编程题汇总)------排序
    牛客网编程练习(2018校招真题编程题汇总)------回文素数
    牛客网编程练习(2018校招真题编程题汇总)------判断题
    牛客网编程练习(2018校招真题编程题汇总)------删除重复字符
    mysql5.7出现大量too many connections及too many open files错误,且配置最大连接数未生效
    commons-lang3之StringUtils
    commons-lang3 事件机制 <EventListenerSupport>
    springboot文件上传下载简单使用
    Redis5.0.4复制
  • 原文地址:https://www.cnblogs.com/the-home-of-123/p/9174426.html
Copyright © 2011-2022 走看看