决策树的一个重要任务,就是为了理解数据中蕴含的知识信息,因此决策树可以使用不熟悉的数据集合,并从中提取出一系列规则,这些机器根据数据集创建规则的过程,就是机器学习的过程。
一、确定划分数据集的决定性特征
信息增益:划分数据集前后信息发生的变化
信息:l(xi)=-log2p(xi),p(xi)是选择该分类的概率
熵(信息的期望值,表示序集无需程度的度量):H=-Σp(xi)log2p(xi)
二、划分数据集:
①计算原始数据集的熵
②将需要划分特征的数据标签抽取出来,计算剩下的数据集的熵
③将两个熵做差,算出计算得信息增益
④对每一个数据标签重复上述操作,最终算出信息增益最大得数据标签,这个数据标签就是决定性特征。
⑤对剩下来得数据标签再进行挑选分类,建立决策树。
三、存在问题
决策树可能会产生过多的数据集划分,从而产生过度匹配数据集的问题。我们可以通过裁剪决策树,合并相邻的无法产生大量信息增益的叶节点,消除过度匹配问题
决策树和KNN算法都是谈论具有明确分类的分类算法,朴素贝叶斯分类是一定概率的分类算法。
import operator
import matplotlib.pyplot as plt
from math import log
#计算给定数据集的香农熵
def calcShannonEnt(dataset):
numEntries=len(dataset)
labelCounts={}
for featVec in dataset:
currentLabel=featVec[-1]
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel]=0
labelCounts[currentLabel]+=1
shannonEnt=0.0
for key in labelCounts:
prob=float(labelCounts[key])/numEntries
shannonEnt-=prob*log(prob,2)
return shannonEnt
#mydat=[[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']]
#print(calcShannonEnt(mydat))
#按照给定特征划分数据集
def splitDataSet(dataset,axis,value):
retDataSet=[]
for featVec in dataset:
if featVec[axis] == value:
reducedFeatVec=featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
def chooseBestFeatureTosplit(dataset):
numFeatures=len(dataset[0])-1#获取特征数目-1
baseEntropy=calcShannonEnt(dataset)#计算原始香农熵
bestInfoGain=0.0;bestFeature=-1
for i in range(numFeatures):
featList=[example[i] for example in dataset]#建立dataset中的每一列标签的值
uniqueVals=set(featList)#唯一化
newEntropy=0.0
for value in uniqueVals:
subDataSet=splitDataSet(dataset,i,value)
prob=len(subDataSet)/float(len(dataset))
newEntropy+=prob*calcShannonEnt(subDataSet)#计算抽取了value后的香农熵之和
infoGain=baseEntropy-newEntropy#计算信息增益
if(infoGain>bestInfoGain):#找到最好的划分
bestInfoGain=infoGain
bestFeature=i
return bestFeature
#mydat=[[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']]
#print(chooseBestFeatureTosplit(mydat)) OUTPUT:0表明第0个特征用于划分数据集最好的特征
#返回出现最多次的分类名称
def majorityCnt(classList):
classCount={}
for vote in classList:
if vote not in classCount.keys():classCount[vote]=0
classCount[vote]+=1
sortedClassCount=sorted(classCount.iteritems(),key=operator.itemgetter(1),reversed=True)
return sortedClassCount[0][0]
#递归构建决策树
def createTree(dataset,labels):
classList=[example[-1] for example in dataset]#最后一列
if classList.count(classList[0])==len(classList):#count() 方法用于统计某个元素在列表中出现的次数,如果所有的类别都是相同的
return classList[0]
if len(dataset[0])==1:#只剩下一个
return majorityCnt(classList)
bestFeat=chooseBestFeatureTosplit(dataset)
bestFeatLabel=labels[bestFeat]
myTree={bestFeatLabel:{}}
del(labels[bestFeat])
featValues=[example[bestFeat] for example in dataset]
uniqueVals=set(featValues)
for value in uniqueVals:
subLabels=labels[:]
myTree[bestFeatLabel][value]=createTree(splitDataSet(dataset,bestFeat,value),subLabels)
return myTree
#绘制图形
decisionNode=dict(boxstyle="sawtooth",fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args=dict(arrowstyle="<-")
def plotNode(nodeText,centerPt,parentPt,nodeType):
createPlot.ax1.annotate(nodeText,xy=parentPt,xycoords='axes fraction',xytext=centerPt,textcoords='axes fraction',va="center",
ha="center",bbox=nodeType,arrowprops=arrow_args)
def getNumLeafs(myTree):
numLeafs=0
firstStr=list(myTree.keys())[0]
secondDict=myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
numLeafs+=getNumLeafs(secondDict[key])
else:
numLeafs+=1
return numLeafs
def getTreeDepth(myTree):
maxDepth=0
firstStr=list(myTree.keys())[0]
secondDict=myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
thisDepth=1+getTreeDepth(secondDict[key])
else: thisDepth=1
if thisDepth>maxDepth:maxDepth=thisDepth
return maxDepth
def plotMidText(cntrPt,parentPt,txtString):
xMid=(parentPt[0]-cntrPt[0])/2.0+cntrPt[0]
yMid=(parentPt[1]-cntrPt[1])/2.0+cntrPt[1]
createPlot.ax1.text(xMid,yMid,txtString)
def plotTree(myTree,parentPt,nodeTxt):
numLeafs=getNumLeafs(myTree)
depth=getTreeDepth(myTree)
firstStr=list(myTree.keys())[0]
cntrPt=(plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
plotMidText(cntrPt,parentPt,nodeTxt)
plotNode(firstStr,cntrPt,parentPt,decisionNode)
secondDict=myTree[firstStr]
plotTree.yOff=plotTree.yOff-1.0/plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
plotTree(secondDict[key],cntrPt,str(key))
else:
plotTree.xOff=plotTree.xOff+1.0/plotTree.totalW
plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD
def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD=float(getTreeDepth(inTree))
plotTree.xOff=-0.5/plotTree.totalW;plotTree.yOff=1.0;
plotTree(inTree,(0.5,1.0),'')
plt.show()
dataset=[[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']]
labels=['no surfacing','flippers']#前两列属性的名称,dataset第三列表示的是目标变量
myTree=createTree(dataset,labels)
print(myTree)
createPlot(myTree)
Input:传入的数据分为n+1列,前n列对应n个属性(labels),最后一列是目标变量
Output:决策树
------------------------
虽然内容还没有跑出来,问题出现在import导入的matplotlib这个包,问题分析是之前下载了annacode,和自带的python冲突了,我把annacode的环境加进来,好像也不行,代码界面报错消失了,但是一运行就会报错,有博主说是因为创建的不是python package下的python文件,测试后发现仍然不行,最终!!!!!成了!!!!
解决anaconda与pycharm冲突详情见https://www.cnblogs.com/code-fun/p/12488711.html
最终!