Building a Decision Tree with the ID3 Algorithm
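For reference, the code below implements the standard definitions (these formulas are added here for clarity; they are exactly what calcShannonEnt and chooseBestFeatureToSplit compute): the Shannon entropy of a data set $D$ with class proportions $p_i$ is $H(D) = -\sum_i p_i \log_2 p_i$, and ID3 splits on the feature $A$ with the largest information gain $\mathrm{Gain}(D, A) = H(D) - \sum_{v \in \mathrm{Values}(A)} \frac{|D_v|}{|D|} H(D_v)$, where $D_v$ is the subset of $D$ with $A = v$.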
# Author: Qian Chenglong
# labels: the feature names; dataSet: n features per row plus the target class

from math import log
import operator


'''Compute the Shannon entropy of a data set'''
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:  # count how often each class label appears
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries  # probability of each label
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
'''The higher the entropy, the more mixed the data'''


'''Split the data set on a given feature'''
def splitDataSet(dataSet, axis, value):  # data set, index of the splitting feature, feature value
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]  # keep every column except the splitting feature
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
'''Extracts the instances that match the given feature value'''


'''Iterate over all features and choose the split with the largest information gain'''
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1  # number of features; the last column is the label, hence -1
    baseEntropy = calcShannonEnt(dataSet)  # entropy of the original data set
    bestInfoGain = 0.0; bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]  # all values of feature i
        uniqueVals = set(featList)  # a set removes duplicate values
        newEntropy = 0.0
        for value in uniqueVals:  # iterate over all values of the current feature
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)  # weighted entropy after the split
        infoGain = baseEntropy - newEntropy  # entropy reduction, i.e. the information gain
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


'''Return the class that occurs most often'''
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)  # reverse=True: descending; reverse=False: ascending (default)
    return sortedClassCount[0][0]


def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]  # list of class labels
    if classList.count(classList[0]) == len(classList):  # all instances share the same class
        return classList[0]  # stop splitting
    if len(dataSet[0]) == 1:  # all features used up, only the label column remains: return the majority class
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}  # use the feature name as the key of the (sub)tree
    del(labels[bestFeat])  # remove the feature that has been used
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]  # copy the labels so recursion doesn't mess up the existing list
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree


'''Count the leaf nodes'''
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else: numLeafs += 1
    return numLeafs


'''Get the depth of the tree'''
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else: thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


'''Classify a new instance with the decision tree'''
def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]  # the first key in the dict: the root feature name
    secondDict = inputTree[firstStr]  # the second-level dict: the subtree below the root
    featIndex = featLabels.index(firstStr)
    key = testVec[featIndex]
    valueOfFeat = secondDict[key]
    if isinstance(valueOfFeat, dict):
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else: classLabel = valueOfFeat
    return classLabel


'''Store the tree'''
def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'wb')  # pickle requires binary mode
    pickle.dump(inputTree, fw)
    fw.close()


'''Load the tree'''
def grabTree(filename):
    import pickle
    fr = open(filename, 'rb')  # pickle requires binary mode
    return pickle.load(fr)
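
What follows is a minimal usage sketch. The two-feature toy dataset, the feature names, and the file name are assumptions added for illustration; they are not part of the original code.

# Usage sketch with an assumed toy dataset (not from the original listing).
def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']  # feature names, one per column
    return dataSet, labels

myDat, labels = createDataSet()
myTree = createTree(myDat, labels[:])  # pass a copy: createTree deletes used labels
print(myTree)   # {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
print(getNumLeafs(myTree), getTreeDepth(myTree))           # 3 2
print(classify(myTree, labels, [1, 0]))                    # 'no'
storeTree(myTree, 'tree.pkl')   # hypothetical file name
print(grabTree('tree.pkl'))     # same dict as myTree

Note that labels[:] is passed to createTree because the function mutates its labels argument with del; the untouched original list is then reused by classify to look up feature indices.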