zoukankan      html  css  js  c++  java
  • 机器学习实战源码决策树的构造

     1 from math import log
     2 import operator
     3 
     4 def createDataSet():
     5     dataSet = [[1,1,"yes"],
     6                [1,1,"yes"],
     7                [1,0,"no"],
     8                [0,1,"no"],
     9                [0,1,"no"]]
    10     labels = ["no surfacing","flippers"]
    11     return dataSet,labels
    12 def calcShannonEnt(dataSet):
    13     numEntries = len(dataSet)
    14     labelCounts = {}
    15     for featVec in dataSet:
    16         currentLabel = featVec[-1]
    17         if currentLabel not in labelCounts.keys():
    18             labelCounts[currentLabel] = 0
    19         labelCounts[currentLabel] += 1
    20     shannonEnt = 0.0
    21     for key in labelCounts:
    22         prob = float(labelCounts[key]) / numEntries
    23         shannonEnt -= prob * log(prob,2)
    24     return shannonEnt
    25 def splitdataSet(dataSet,axis,value):
    26     retDataSet = []
    27     for featVec in dataSet:
    28         if featVec[axis] == value:
    29             reducedFeatVec = featVec[:axis]
    30             reducedFeatVec.extend(featVec[axis + 1:])
    31             retDataSet.append(reducedFeatVec)
    32     return retDataSet
    33 def chooseBestFeatureToSplit(dataSet):
    34     numFeatures = len(dataSet[0]) - 1
    35     baseEntropy = calcShannonEnt(dataSet)
    36     bestInfoGain = 0.0;bestFeature = -1
    37     for i in range(numFeatures):
    38         featList = [example[i] for example in dataSet]
    39         uniqueVals = set(featList)
    40         newEntropy = 0.0
    41         for value in uniqueVals:
    42             subDataSet = splitdataSet(dataSet,i,value)
    43             prob = len(subDataSet) / float(len(dataSet))
    44             newEntropy += prob * calcShannonEnt(subDataSet)
    45         infoGain = baseEntropy - newEntropy
    46         if (infoGain > bestInfoGain):
    47             bestInfoGain = infoGain
    48             bestFeature = i
    49     return bestFeature
    50 def majorityCnt(classList):
    51     classCount = {}
    52     for vote in classList:
    53         if vote not in classCount.keys():
    54             classCount[vote] = 0
    55         classCount[vote] += 1
    56     sortedClassCount = sorted(classCount.iteritems(),key=operator.itemgetter(1),reverse=True)
    57     return sortedClassCount[0][0]
    58 def createTree(dataSet,labels):
    59     classList = [example[-1] for example in dataSet]
    60     if classList.count(classList[0]) == len(classList):
    61         return classList[0]
    62     if len(dataSet[0]) == 1:
    63         return majorityCnt(classList)
    64     bestFeat = chooseBestFeatureToSplit(dataSet)
    65     bestFeatLabel = labels[bestFeat]
    66     myTree = {bestFeatLabel:{}}
    67     del(labels[bestFeat])
    68     featValues = [example[bestFeat] for example in dataSet]
    69     uniqueVals = set(featValues)
    70     for value in uniqueVals:
    71         subLabels = labels[:]
    72         myTree[bestFeatLabel][value] = createTree(splitdataSet(dataSet,bestFeat,value),subLabels)
    73     return myTree
    74 if __name__ == "__main__":
    75     myDat,labels = createDataSet()
    76     #print calcShannonEnt(myDat)
    77     #print splitdataSet(myDat,0,1)
    78     #print chooseBestFeatureToSplit(myDat)
    79     myTree = createTree(myDat,labels)
    80     print myTree
  • 相关阅读:
    js保留几位小数
    IE的卸载之路(折腾1个多月,记录下。。)
    百度map
    鼠标滑轮事件监听,兼容各类浏览器
    sql server分页存储过程
    echarts(3.0)的基本使用(标签式导入)
    datagrid加分组后的效果
    python文件操作
    python求100以内素数
    python 三元运算符
  • 原文地址:https://www.cnblogs.com/guochangyu/p/7718230.html
Copyright © 2011-2022 走看看