zoukankan      html  css  js  c++  java
  • 决策树 书上的例题

    from math import  log
    import operator
    
    def calcShannonEnt(dataSet):
        """Return the base-2 Shannon entropy of the class labels in dataSet.

        Each record's class label is taken from its last element. An empty
        dataSet has no labels to sum over and therefore yields 0.0.
        """
        total = len(dataSet)
        # Tally how many records carry each class label.
        tally = {}
        for record in dataSet:
            label = record[-1]
            tally[label] = tally.get(label, 0) + 1
        # H = -sum(p * log2(p)) over every observed label.
        entropy = 0.0
        for count in tally.values():
            p = float(count) / total
            entropy -= p * log(p, 2)
        return entropy
    
    def createDataSet():
        """Return the small sample data set and its feature names.

        Returns a (dataSet, labels) pair: dataSet is a list of five records
        of the form [feature0, feature1, class_label], and labels names the
        two feature columns in order. Fresh lists are built on every call.
        """
        labels = ['no surfacing', 'flippers']
        samples = (
            (1, 1, 'yes'),
            (1, 1, 'yes'),
            (1, 0, 'no'),
            (0, 1, 'no'),
            (0, 1, 'no'),
        )
        # Records are handed out as mutable lists, one new list per record.
        dataSet = [list(sample) for sample in samples]
        return dataSet, labels
    
    def splitDataSet(dataSet,axis,value):
        """Select the records whose column `axis` equals `value`.

        Each selected record is returned as a new list with the `axis`
        column removed; the records in dataSet itself are not modified.
        """
        selected = []
        for record in dataSet:
            if not record[axis] == value:
                continue
            # Copy everything before the split column, then append the rest.
            trimmed = record[:axis]
            trimmed.extend(record[axis + 1:])
            selected.append(trimmed)
        return selected
    
    def chooseBestFeatureToSplit(dataSet):
        """Return the index of the feature with the largest information gain.

        Every column except the last (the class label) is treated as a
        feature. For each feature the data set is partitioned by the
        feature's distinct values and the size-weighted entropy of the
        partitions is compared with the entropy of the whole set.

        Returns -1 when no feature yields a strictly positive gain.
        """
        featureCount = len(dataSet[0]) - 1
        totalEntropy = calcShannonEnt(dataSet)
        recordCount = float(len(dataSet))
        topGain = 0.0
        topFeature = -1
        for index in range(featureCount):
            distinctValues = set([record[index] for record in dataSet])
            # Entropy remaining after the split, weighted by partition size.
            splitEntropy = 0.0
            for candidate in distinctValues:
                partition = splitDataSet(dataSet, index, candidate)
                weight = len(partition) / recordCount
                splitEntropy += weight * calcShannonEnt(partition)
            gain = totalEntropy - splitEntropy
            # Strict comparison: on ties the earliest feature is kept.
            if gain > topGain:
                topGain = gain
                topFeature = index
        return topFeature
    
    def majorityCnt(classList):
        """Return the class label that occurs most often in classList.

        When several labels share the highest count, the one that was
        counted first wins (the sort is stable and the count dict keeps
        insertion order on Python 3.7+).

        Raises IndexError if classList is empty.
        """
        classCount = {}
        for vote in classList:
            classCount[vote] = classCount.get(vote, 0) + 1
        # dict.items() exists on both Python 2 and 3; the previous
        # dict.iteritems() call raised AttributeError on Python 3.
        sortedClassCount = sorted(classCount.items(),
                                  key=operator.itemgetter(1),
                                  reverse=True)
        return sortedClassCount[0][0]
    
    def createTree(dataSet,labels):
        """Recursively build a decision tree from dataSet.

        dataSet is a list of records whose last element is the class label;
        labels holds the feature names for the remaining columns, in order.

        Returns either a class label (leaf) or a nested dict of the form
        {feature_name: {feature_value: subtree_or_label, ...}}.

        The caller's `labels` list is left untouched: the chosen feature
        name is removed from a private copy before recursing.
        """
        classList = [example[-1] for example in dataSet]
        # Every record has the same class: this node is a pure leaf.
        if classList.count(classList[0]) == len(classList):
            return classList[0]
        # No feature columns are left: fall back to the majority class.
        if len(dataSet[0]) == 1:
            return majorityCnt(classList)
        bestFeat = chooseBestFeatureToSplit(dataSet)
        # chooseBestFeatureToSplit returns -1 when no feature gives a
        # positive information gain. Using -1 as an index would split on
        # the class-label column, so settle for the majority class instead.
        if bestFeat < 0:
            return majorityCnt(classList)
        bestFeatLabel = labels[bestFeat]
        # Drop the used feature name from a copy so the caller's list
        # is not mutated.
        subLabels = list(labels)
        del subLabels[bestFeat]
        myTree = {bestFeatLabel: {}}
        uniqueVals = set(example[bestFeat] for example in dataSet)
        for value in uniqueVals:
            myTree[bestFeatLabel][value] = createTree(
                splitDataSet(dataSet, bestFeat, value), subLabels)
        return myTree
    View Code
  • 相关阅读:
    Android中WebView如何加载JavaScript脚本
    Android中WebView如何加载JavaScript脚本
    Android中WebView如何加载JavaScript脚本
    Android如何使用SQLite数据库
    Android如何使用SQLite数据库
    Android如何使用SQLite数据库
    __declspec(dllimport)的作用
    __declspec,__cdecl,__stdcall都是什么意思?有什么作用?
    #pragma pack(push,1)与#pragma pack(1)的区别
    #pragma pack(n) 的作用
  • 原文地址:https://www.cnblogs.com/cherryMJY/p/8529369.html
Copyright © 2011-2022 走看看