"""ID3 decision-tree construction on small in-memory datasets.

A dataset is a list of rows; each row holds the feature values followed by
the class label as its last element.
"""
from math import log
import operator
from collections import Counter


def createDataSet():
    """Return a toy dataset and the names of its two feature columns."""
    dataSet = [[1, 1, "yes"],
               [1, 1, "yes"],
               [1, 0, "no"],
               [0, 1, "no"],
               [0, 1, "no"]]
    labels = ["no surfacing", "flippers"]
    return dataSet, labels


def calcShannonEnt(dataSet):
    """Return the Shannon entropy, in bits, of the class labels in dataSet.

    The class label of each row is its last element. An empty dataset has
    entropy 0.0.
    """
    numEntries = len(dataSet)
    labelCounts = Counter(featVec[-1] for featVec in dataSet)
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = float(count) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


def splitdataSet(dataSet, axis, value):
    """Return the rows whose column `axis` equals `value`, minus that column.

    New row lists are built; the input rows are not modified.
    """
    return [featVec[:axis] + featVec[axis + 1:]
            for featVec in dataSet
            if featVec[axis] == value]


def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature column with the largest information gain.

    Returns -1 when no feature yields a strictly positive gain (for example
    when every remaining feature column is constant); callers must handle
    that sentinel.
    """
    numFeatures = len(dataSet[0]) - 1  # last column is the class label
    numRows = float(len(dataSet))
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        newEntropy = 0.0
        for value in set(example[i] for example in dataSet):
            subDataSet = splitdataSet(dataSet, i, value)
            prob = len(subDataSet) / numRows
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


def majorityCnt(classList):
    """Return the most frequent label in classList.

    Ties go to the label encountered first. Raises IndexError when
    classList is empty.
    """
    return Counter(classList).most_common(1)[0][0]


def createTree(dataSet, labels):
    """Recursively build an ID3 decision tree.

    Args:
        dataSet: list of rows; each row's last element is its class label.
        labels: names of the feature columns, in column order. This list is
            not modified.

    Returns:
        A class label for a leaf, otherwise a nested dict of the form
        {featureName: {featureValue: subtree, ...}}.
    """
    classList = [example[-1] for example in dataSet]
    # Leaf: every row carries the same class.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Leaf: only the class column remains, so fall back to a majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    # No feature separates the classes (e.g. identical feature vectors with
    # different labels). Using index -1 would split on the class column.
    if bestFeat == -1:
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    # Build the reduced label list instead of deleting from the caller's list.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    myTree = {bestFeatLabel: {}}
    for value in set(example[bestFeat] for example in dataSet):
        myTree[bestFeatLabel][value] = createTree(
            splitdataSet(dataSet, bestFeat, value), subLabels)
    return myTree


if __name__ == "__main__":
    myDat, labels = createDataSet()
    myTree = createTree(myDat, labels)
    print(myTree)