zoukankan      html  css  js  c++  java
  • 决策树 书上的例题

    from math import  log
    import operator
    
    def calcShannonEnt(dataSet):
        """Return the Shannon entropy (log base 2) of the class labels in dataSet.

        The class label of every record is read from its last element.
        An empty dataSet has no labels to tally, so the result is 0.0.
        """
        total = len(dataSet)
        # Tally how many records carry each class label.
        tally = {}
        for record in dataSet:
            label = record[-1]
            tally[label] = tally.get(label, 0) + 1
        # Accumulate -p * log2(p) over the observed label frequencies.
        entropy = 0.0
        for count in tally.values():
            p = float(count) / total
            entropy -= p * log(p, 2)
        return entropy
    
    def createDataSet():
        """Build the small sample data set used by the decision-tree example.

        Returns a (dataSet, labels) pair. Each record in dataSet holds two
        feature values followed by its class label ('yes' or 'no'); labels
        names the two feature columns in order. Fresh lists are created on
        every call.
        """
        featureNames = ['no surfacing', 'flippers']
        records = [
            [1, 1, 'yes'],
            [1, 1, 'yes'],
            [1, 0, 'no'],
            [0, 1, 'no'],
            [0, 1, 'no'],
        ]
        return records, featureNames
    
    def splitDataSet(dataSet, axis, value):
        """Return the records whose entry at index axis equals value, minus that entry.

        The input records are left untouched; each returned record is a new
        list made of the elements before axis followed by the elements after it.
        """
        def withoutAxis(record):
            # Slice first so the caller's record is never modified.
            trimmed = record[:axis]
            trimmed.extend(record[axis + 1:])
            return trimmed

        return [withoutAxis(record) for record in dataSet if record[axis] == value]
    
    def chooseBestFeatureToSplit(dataSet):
        """Return the index of the feature with the largest information gain.

        Every column except the last (the class label) is treated as a
        candidate feature. For each candidate, the data set is partitioned by
        that feature's distinct values and the size-weighted entropy of the
        partitions is subtracted from the entropy of the whole set.

        Returns -1 when no feature gives a strictly positive gain.
        """
        featureCount = len(dataSet[0]) - 1
        parentEntropy = calcShannonEnt(dataSet)

        def splitEntropy(column):
            # Size-weighted entropy of the partitions induced by this column.
            distinctValues = set([record[column] for record in dataSet])
            weighted = 0.0
            for candidate in distinctValues:
                partition = splitDataSet(dataSet, column, candidate)
                weight = len(partition) / float(len(dataSet))
                weighted += weight * calcShannonEnt(partition)
            return weighted

        winner = -1
        winningGain = 0.0
        for column in range(featureCount):
            gain = parentEntropy - splitEntropy(column)
            # Strict comparison: on equal gain the earlier feature is kept.
            if gain > winningGain:
                winner, winningGain = column, gain
        return winner
    
    def majorityCnt(classList):
        """Return the class label that occurs most often in classList.

        Sorting is stable, so labels with equal counts keep the order in which
        the dict yields them; on Python 3.7+ (insertion-ordered dicts) a tie
        therefore goes to the label that appeared first.

        Raises IndexError when classList is empty.
        """
        classCount = {}
        for vote in classList:
            classCount[vote] = classCount.get(vote, 0) + 1
        # dict.iteritems() was removed in Python 3; items() works on both
        # Python 2 and Python 3.
        sortedClassCount = sorted(classCount.items(),
                                  key=operator.itemgetter(1),
                                  reverse=True)
        return sortedClassCount[0][0]
    
    def createTree(dataSet, labels):
        """Recursively build a decision tree from dataSet.

        dataSet is a list of records whose last element is the class label;
        labels holds the feature name for each remaining column, in order.

        Returns either a class label (a leaf) or a nested dict of the form
        {featureName: {featureValue: subtree, ...}}.

        The caller's labels list is not modified.
        """
        classList = [example[-1] for example in dataSet]
        # All records share one class: this branch is a pure leaf.
        if classList.count(classList[0]) == len(classList):
            return classList[0]
        # Only the class column is left: fall back to a majority vote.
        if len(dataSet[0]) == 1:
            return majorityCnt(classList)
        bestFeat = chooseBestFeatureToSplit(dataSet)
        # chooseBestFeatureToSplit returns -1 when no feature improves on the
        # current entropy. Using -1 as an index would split on the class-label
        # column itself, so stop here with a majority vote instead.
        if bestFeat == -1:
            return majorityCnt(classList)
        bestFeatLabel = labels[bestFeat]
        # Work on a copy so the caller's labels list stays intact.
        subLabels = list(labels)
        del subLabels[bestFeat]
        myTree = {bestFeatLabel: {}}
        uniqueVals = set([example[bestFeat] for example in dataSet])
        for value in uniqueVals:
            myTree[bestFeatLabel][value] = createTree(
                splitDataSet(dataSet, bestFeat, value), subLabels)
        return myTree
    View Code
  • 相关阅读:
    运维实战:两台服务器http方式共享yum软件仓库
    初始化thinkphp6.0出现的问题解决
    记一次续签SSL证书导致微信小程序部分机型无法访问网站接口
    微信小程序-订阅消息验证发送值有效格式
    微信小程序分包优化
    MySQL timeout 参数详解
    mysql 事件
    springboot 远程拉取配置中心配置
    使用springboot的resttmplate请求远程服务的时候报 403问题
    for 循环 与增强的for循环 经验小结
  • 原文地址:https://www.cnblogs.com/cherryMJY/p/8529369.html
Copyright © 2011-2022 走看看