zoukankan      html  css  js  c++  java
  • 决策树 (decision tree)

    内容学习于 ApacheCN github 

    定义:

      分类决策树模型是一种描述对实例进行分类的树形结构。决策树由结点(node)和有向边(directed edge)组成。结点有两种类型:内部结点(internal node)和叶结点(leaf node)。内部结点表示一个特征或属性(features),叶结点表示一个类(labels)。

      用决策树对需要测试的实例进行分类:从根节点开始,对实例的某一特征进行测试,根据测试结果,将实例分配到其子结点;这时,每一个子结点对应着该特征的一个取值。如此递归地对实例进行测试并分配,直至达到叶结点。最后将实例分配到叶结点的类中。

    原理: 

    决策树 须知概念

    信息熵 & 信息增益

    熵(entropy): 熵指的是体系的混乱的程度,在不同的学科中也有引申出的更为具体的定义,是各领域十分重要的参量。

    信息论(information theory)中的熵(香农熵): 是一种信息的度量方式,表示信息的混乱程度,也就是说:信息越有序,信息熵越低。例如:火柴有序放在火柴盒里,熵值很低;相反,熵值很高。

    信息增益(information gain): 在划分数据集前后信息发生的变化称为信息增益。

     

    决策树算法特点

    优点:计算复杂度不高,输出结果易于理解,数据有缺失也能跑,可以处理不相关特征。
    缺点:容易过拟合。
    适用数据类型:数值型和标称型。

    如何构造一个决策树?

    def createBranch():
    '''
    此处运用了迭代的思想。 感兴趣可以搜索 迭代 recursion, 甚至是 dynamic programing。
    '''
    检测数据集中的所有数据的分类标签是否相同:
    If so return 类标签
    Else:
    寻找划分数据集的最好特征(划分之后信息熵最小,也就是信息增益最大的特征)
    划分数据集
    创建分支节点
    for 每个划分的子集
    调用函数 createBranch (创建分支的函数)并增加返回结果到分支节点中
    return 分支节点

     1 from math import log
     2 
     3 def createDataSet():
     4     dataSet = [
     5         [1, 1, 'yes'],
     6         [1, 1, 'yes'],
     7         [1, 0, 'no'],
     8         [0, 1, 'no'],
     9         [0, 1, 'no'],
    10     ]
    11     labels = ['no surfacing', 'flippers']
    12     return dataSet, labels
    13 
    14 
    15 def calcShannonEnt(dataSet):
    16     # 参与计算的数据量
    17     numEntries = len(dataSet)
    18     # 分类标签出现的次数
    19     labelCounts = {}
    20     for foo in dataSet:
    21         currentLabel = foo[-1]
    22         # 分类写入字典,不存在则创建,并记录当前类别的次数
    23         if currentLabel not in labelCounts.keys():
    24             labelCounts[currentLabel] = 0
    25         labelCounts[currentLabel] += 1
    26     # 对于 label 标签的占比,求出 label 标签的香农熵
    27     shannonEnt = 0.0
    28     for key in labelCounts:
    29         # 计算每个标签出现的频率
    30         prob = labelCounts[key] / numEntries
    31         # 计算香农熵,以 2 为底求对数
    32         shannonEnt -= prob * log(prob, 2)
    33     return shannonEnt
    34 
    35 def splitDataSet(dataSet, index, value):
    36     retDataSet = []
    37     for featVec in dataSet:
    38         # 除去 index 列为 value 的数据集
    39         if featVec[index] == value:
    40             # 取 index 列前的数据列
    41             reducedFeatVec = featVec[:index]
    42             # 取 index 列后的数据列
    43             reducedFeatVec.extend(featVec[index + 1:])
    44             retDataSet.append(reducedFeatVec)
    45     return retDataSet
    46 
    47 def chooseBestFeatureToSplit(dataSet):
    48     # 有多少列的特征 Feature ,最后一列是类 label
    49     numFeature = len(dataSet) - 1
    50     # 数据集的原始信息熵
    51     baseEntropy = calcShannonEnt(dataSet)
    52     # 记录最优的信息增益和最优的特征 Feature 编号
    53     bestInfoGain, bestFeature = 0.0, -1
    54     for i in range(numFeature):
    55         # 获取对应特征 Feature 下的所有数据
    56         featList = [example[i] for example in dataSet]
    57         # 对特征列表进行去重
    58         uniqueVals = set(featList)
    59         # 创建一个临时信息熵
    60         tempEntropy = 0.0
    61         # 遍历某一列 value 集合计算该列的信息熵
    62         for value in uniqueVals:
    63             # 取去除第 i 列值为 value  的子集
    64             subDataSet = splitDataSet(dataSet, i, value)
    65             # 概率
    66             prob = len(subDataSet) / len(dataSet)
    67             # 计算信息熵
    68             tempEntropy += prob * calcShannonEnt(subDataSet)
    69         infoGain = baseEntropy - tempEntropy
    70         if infoGain > bestInfoGain:
    71             bestInfoGain = infoGain
    72             bestFeature = i
    73     return bestFeature
    决策树部分代码
  • 相关阅读:
    go入门4---数据
    hibernate之关联关系一对多
    hibernate的主键生成策略
    hibernate的入门
    struts--CRUD优化(图片上传)
    struts2--CRUD
    struts2--入
    Maven环境搭建
    EasyUI--增删改查
    easyui--权限管理
  • 原文地址:https://www.cnblogs.com/xsmile/p/10696175.html
Copyright © 2011-2022 走看看