zoukankan      html  css  js  c++  java
  • 机器学习实战源码决策树的构造

     1 from math import log
     2 import operator
     3 
     4 def createDataSet():
     5     dataSet = [[1,1,"yes"],
     6                [1,1,"yes"],
     7                [1,0,"no"],
     8                [0,1,"no"],
     9                [0,1,"no"]]
    10     labels = ["no surfacing","flippers"]
    11     return dataSet,labels
    12 def calcShannonEnt(dataSet):
    13     numEntries = len(dataSet)
    14     labelCounts = {}
    15     for featVec in dataSet:
    16         currentLabel = featVec[-1]
    17         if currentLabel not in labelCounts.keys():
    18             labelCounts[currentLabel] = 0
    19         labelCounts[currentLabel] += 1
    20     shannonEnt = 0.0
    21     for key in labelCounts:
    22         prob = float(labelCounts[key]) / numEntries
    23         shannonEnt -= prob * log(prob,2)
    24     return shannonEnt
    25 def splitdataSet(dataSet,axis,value):
    26     retDataSet = []
    27     for featVec in dataSet:
    28         if featVec[axis] == value:
    29             reducedFeatVec = featVec[:axis]
    30             reducedFeatVec.extend(featVec[axis + 1:])
    31             retDataSet.append(reducedFeatVec)
    32     return retDataSet
    33 def chooseBestFeatureToSplit(dataSet):
    34     numFeatures = len(dataSet[0]) - 1
    35     baseEntropy = calcShannonEnt(dataSet)
    36     bestInfoGain = 0.0;bestFeature = -1
    37     for i in range(numFeatures):
    38         featList = [example[i] for example in dataSet]
    39         uniqueVals = set(featList)
    40         newEntropy = 0.0
    41         for value in uniqueVals:
    42             subDataSet = splitdataSet(dataSet,i,value)
    43             prob = len(subDataSet) / float(len(dataSet))
    44             newEntropy += prob * calcShannonEnt(subDataSet)
    45         infoGain = baseEntropy - newEntropy
    46         if (infoGain > bestInfoGain):
    47             bestInfoGain = infoGain
    48             bestFeature = i
    49     return bestFeature
    50 def majorityCnt(classList):
    51     classCount = {}
    52     for vote in classList:
    53         if vote not in classCount.keys():
    54             classCount[vote] = 0
    55         classCount[vote] += 1
    56     sortedClassCount = sorted(classCount.iteritems(),key=operator.itemgetter(1),reverse=True)
    57     return sortedClassCount[0][0]
    58 def createTree(dataSet,labels):
    59     classList = [example[-1] for example in dataSet]
    60     if classList.count(classList[0]) == len(classList):
    61         return classList[0]
    62     if len(dataSet[0]) == 1:
    63         return majorityCnt(classList)
    64     bestFeat = chooseBestFeatureToSplit(dataSet)
    65     bestFeatLabel = labels[bestFeat]
    66     myTree = {bestFeatLabel:{}}
    67     del(labels[bestFeat])
    68     featValues = [example[bestFeat] for example in dataSet]
    69     uniqueVals = set(featValues)
    70     for value in uniqueVals:
    71         subLabels = labels[:]
    72         myTree[bestFeatLabel][value] = createTree(splitdataSet(dataSet,bestFeat,value),subLabels)
    73     return myTree
    74 if __name__ == "__main__":
    75     myDat,labels = createDataSet()
    76     #print calcShannonEnt(myDat)
    77     #print splitdataSet(myDat,0,1)
    78     #print chooseBestFeatureToSplit(myDat)
    79     myTree = createTree(myDat,labels)
    80     print myTree
  • 相关阅读:
    zoj 1610(明天做)
    在C#中ParameterizedThreadStart和ThreadStart区别
    datagridview显示行号
    “不允许对64位应用程序进行修改”的解决方法
    SQL查询表和存储过程创建修改日期
    推荐一个代码自动完成的工具AutoCode
    .net中的认证(authentication)与授权(authorization)
    SQL语句使用总结(二)
    C#/WinForm给控件加入hint文字
    SQL Server 2008 Express 安装时提示“重启计算机失败”
  • 原文地址:https://www.cnblogs.com/guochangyu/p/7718230.html
Copyright © 2011-2022 走看看