决策树的工作过程就是一个自顶向下不断推断分解,逐步缩小待猜测事物范围的过程,很显然其关键是如何构造一个满足需求的树结构。
# 决策树的构造
在构造决策树时,首先需要解决的问题便是,当前数据集上哪个特征在划分数据分类时起决定作用。为了找到决定性的特征,我们需要评估每个特征。选出最优的特征后,原始数据集就会被划分为几个数据子集,这些数据子集会分布在第一个决策点的所有分支上。这时候各分支会出现两种情况:一种是分支下的数据属于同一类型,则说明当前分支已经正确地划分数据分类,无需进一步对数据集进行分割;另一种是子集内数据不属于同一类型,则需要重复上面的过程,在剩余特征中选择一个划分更细的数据子集。如此递归构造出整棵树。
创建分支的伪代码函数createBranch( ) 如下所示:
检测数据集中的每个子项是否属于同一分类:
if so return 类标签;
Else
寻找划分数据集的最好特征
划分数据集
创建分支节点
for 每个划分的子集
调用函数createBranch并增加返回结果到分支节点中
return 分支节点
## 信息增益
划分数据集的大原则是:将无序的数据变得更加有序。组织杂乱无章数据的一种方法就是使用信息论度量信息,信息论 是量化处理信息的分支科学。我们可以在划分数据之前或之后使用信息论量化度量信息的内容。在划分数据集之前之后信息发生的变化称为信息增益,知道如何计算信息增益,我们就可以计算每个特征值划分数据集获得的信息增益,获得信息增益最高的特征就是最好的选择。
- 符号的信息定义为:
其中p(x)是选择该分类的概率。
- 熵定义为信息的期望值,为了计算熵,需要计算所有类别所有可能值包含的信息期望值,通过下面公式得到:
根据上面的公式,我们很容易编程实现求出给定数据集的熵作为后文选择划分数据集提供依据。
def calcShannonEnt(dataSet):
'''
计算给定数据集的熵
:param dataSet: 待计算的数据集
:return: 熵
'''
numEntries = len(dataSet)
# 为所有可能分类创建字典,即{label:cnt}
labelCounts = {}
for featVec in dataSet:
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
# 计算香农熵
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key]) / numEntries
shannonEnt -= prob * log(prob, 2)
return shannonEnt
## 划分数据集
上一节有了度量数据集无序程度的评价标准,我们还需要划分数据集,度量划分数据集的熵,以便判断当前是否正确地划分了数据集。我们需要对每个特征划分数据集的结果计算一次信息熵,然后根据信息增益大小确定最好的划分方式。
def splitDataSet(dataSet, axis, value):
'''
按照给定特征划分数据集
:param dataSet: 待分配的数据集
:param axis: 划分数据集的特征
:param value: 划分特征的值
:return: 返回一个挖去axis,且对应值等于value的数据集
'''
# 创建新的list对象,存储子集
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
# 拼接数据, list.extend(seq)横向拼接
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
def chooseBestFeatureToSplit(dataSet):
'''
选择数据集的最佳分割方式
:param dataSet: 待划分的数据集
:return: 最好的划分特征
'''
numFeatures = len(dataSet[0]) - 1
baseEntropy = calcShannonEnt(dataSet)
bestInfoGain = 0.0
bestFeature = -1
for i in range(numFeatures):
# 创建各个特征划分下的值
featList = [example[i] for example in dataSet]
uniqueVals = set(featList)
# 计算每种划分方式的信息熵
newEntropy = 0.0
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet) / float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)
# 计算信息增益,选择最好的那个
infoGain = baseEntropy - newEntropy
if infoGain > bestInfoGain:
bestInfoGain = infoGain
bestFeature = i
return bestFeature
## 递归构建决策树
有了计算熵和有效划分数据集的函数,接下来就是将函数组合构建决策树。回顾一下算法工作原理:原始数据集,然后基于最好的属性值划分数据集,由于特征值可能多于两个,因此可能存在大于 两个分支的数据集划分。第一次划分之后,数据将被向下传递到树分支的下一个节点,在这个节 点上,我们可以再次划分数据。因此我们可以采用递归的原则处理数据集。
递归结束的条件是:程序遍历完所有划分数据集的属性,或者每个分支下的所有实例都具有 相同的分类。如果所有实例具有相同的分类,则得到一个叶子节点或者终止块。
def majorityCnt(classList):
'''
多数表决的方式决定分类
:param classList: 标签列表
:return: 返回表决结果
'''
classCount = {}
for vote in classList:
if vote not in classCount.keys():
classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(), key=lambda x:x[1], reverse=True)
return sortedClassCount[0][0]
def createTree(dataSet, labels):
'''
决策树生成算法主程序
:param dataSet: 原始数据集
:param labels: 标签列表
:return: 决策树
'''
# 包含了数据集的所有类标签
classList = [example[-1] for example in dataSet]
# 类别完全相同则停止继续划分
if classList.count(classList[0]) == len(classList):
return classList[0]
# 遍历到最后一个特征时,返回出现次数最多的类别
if len(dataSet[0]) == 1:
return majorityCnt(classList)
# 寻找当前最佳划分特征
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLabel = labels[bestFeat]
# 开始构建决策树
myTree = {bestFeatLabel:{}}
del(labels[bestFeat])
# 获得剩下的特征
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
# 剩下特征递归生成决策树
for value in uniqueVals:
subLabels = labels[:]
myTree[bestFeatLabel][value] = createTree(splitDataSet(
dataSet, bestFeat, value), subLabels)
return myTree
## 决策树的存储和读取
为了节省计算时间,可以使用python模块pickle序列化对象。序列化对象可以在磁 盘上保存对象,并在需要的时候读取出来。任何对象都可以执行序列化操作,字典对象也不例外。
def storeTree(inputTree, filename):
'''
存储序列化对象
:param inputTree:
:param filename:
:return:
'''
import pickle
fw = open(filename, 'wb')
pickle.dump(inputTree, fw)
fw.close()
def grabTree(filename):
'''
读取序列化对象
:param filename:
:return:
'''
import pickle
fr = open(filename, 'rb')
return pickle.load(fr)
if __name__ == '__main__':
mat, labels = createDataSet()
myTree = createTree(mat, labels)
storeTree(myTree, 'obj.txt')
print(grabTree('obj.txt'))