tree.py
from math import log import operator def calcShannonEnt(dataSet): numEntries = len(dataSet) labelCounts = {} for featVec in dataSet: currentLabel = featVec[-1] if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0 labelCounts[currentLabel] += 1 shannonEnt = 0.0 for key in labelCounts: prob = float(labelCounts[key]) / numEntries shannonEnt -= prob * log(prob, 2) return shannonEnt def createDataSet(): dataSet = [ [1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no'] ] labels = ['no surfacing', 'flippers'] return dataSet, labels def splitDataSet(dataSet, axis, value): retDataSet = [] for featVec in dataSet: if featVec[axis] == value: reducedFeatVec = featVec[ : axis] reducedFeatVec.extend(featVec[axis + 1 : ]) retDataSet.append(reducedFeatVec) return retDataSet def step1(): myDat, labels = createDataSet() print myDat shannonRes = calcShannonEnt(myDat) print shannonRes def step2(): a = [1, 2, 3] b = [4, 5, 6] a.append(b) print a a = [1, 2, 3] a.extend(b) print a def step3(): myDat, labels = createDataSet() print splitDataSet(myDat, 0, 1) print splitDataSet(myDat, 0, 0) def chooseBestFeatToSplit(dataSet): numFeatures = len(dataSet[0]) - 1 baseEntropy = calcShannonEnt(dataSet) bestInfoGain = 0.0; bestFeature =-1 for i in range(numFeatures): featList = [example[i] for example in dataSet] uniqueVals = set(featList) newEntropy = 0.0 for value in uniqueVals: subDataSet = splitDataSet(dataSet, i, value) prob = len(subDataSet) / float(len(dataSet)) newEntropy += prob * calcShannonEnt(subDataSet) infoGain = baseEntropy - newEntropy if(infoGain > bestInfoGain): bestInfoGain = infoGain bestFeature = i return bestFeature def step4(): myDat, labels = createDataSet() res = chooseBestFeatToSplit(myDat) print res print myDat def majorityCnt(classList): classCount = {} for vote in classList: if vote not in classCount.keys(): classCount[vote] = 0 classCount[vote] += 1 sortedClassCount = sorted(classCount.iteritems(), key = operator.itemgetter(1), reverse=True) return sortedClassCount[0][0] def createTree(dataSet, labels): classList = [example[-1] for example in dataSet] if classList.count(classList[0]) == len(classList): return classList[0] if len(dataSet[0]) == 1: return majorityCnt(classList) bestFeat = chooseBestFeatToSplit(dataSet) bestFeatLabel = labels[bestFeat] myTree = {bestFeatLabel:{}} del(labels[bestFeat]) featValues = [example[bestFeat] for example in dataSet] uniqueVals = set(featValues) for value in uniqueVals: subLabels = labels[:] myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels) return myTree def step5(): myDat, labels = createDataSet() print myDat print labels myTree = createTree(myDat, labels) print myTree if __name__ == '__main__': step5()