zoukankan      html  css  js  c++  java
  • 机器学习(周志华老师) 决策树习题4.3

    建立决策树

    import numpy as np 
    import pandas as pd 
    import operator
    from PlotTree import createPlot #调用绘图子程序
    
    # Compute the Shannon entropy of the class labels in dataSet.
    def calcShannonEnt(dataSet):
        """Return the Shannon entropy, in bits, of the class column of dataSet.

        The class label of each row is taken to be its last element. An empty
        dataSet yields 0.0.
        """
        total = len(dataSet)
        counts = {}
        for row in dataSet:
            label = row[-1]
            counts[label] = counts.get(label, 0) + 1
        entropy = 0.0
        for n in counts.values():
            p = float(n) / total
            entropy -= p * np.log2(p)
        return entropy
        
    # Subset of dataSet where a discrete feature takes a given value.
    def splitDataSet(dataSet, axis, value):
        """Return the rows of dataSet whose column `axis` equals `value`.

        Each returned row is a new list with column `axis` removed; the rows of
        dataSet themselves are not modified.
        """
        matching = [row for row in dataSet if row[axis] == value]
        result = []
        for row in matching:
            trimmed = row[:axis]
            trimmed.extend(row[axis + 1:])
            result.append(trimmed)
        return result
    
    # Split on a continuous feature; `direction` selects which side is kept.
    def splitContinuousDataSet(dataSet,axis,value,direction):
        """Return the rows on one side of a threshold, with column `axis` removed.

        direction == 1 keeps rows whose column `axis` is <= value; any other
        direction keeps rows whose column `axis` is > value. Returned rows are
        new lists; dataSet is not modified.
        """
        keep = operator.le if direction == 1 else operator.gt
        result = []
        for row in dataSet:
            if keep(row[axis], value):
                trimmed = row[:axis]
                trimmed.extend(row[axis + 1:])
                result.append(trimmed)
        return result
                
    
    
    # Choose the best feature to split on.
    # Best split point found for each continuous feature, keyed by feature label.
    # Written by chooseBestFeatureToSplit and read back by createTree to label
    # continuous nodes.
    bestSplitDict = {}


    def _isContinuous(value):
        """Return True if a feature value is numeric (int or float).

        Uses the exact type name rather than isinstance so bools are not treated
        as numbers, matching the type checks used elsewhere in this file.
        """
        return type(value).__name__ in ('float', 'int')


    def _bestContinuousSplit(dataSet, axis):
        """Find the best binary split point for the continuous column `axis`.

        Candidate split points are the midpoints between consecutive sorted
        values. Returns (split_point, weighted_entropy) for the candidate with the
        lowest weighted entropy. With no candidates (fewer than two rows) it
        returns the sentinel (0, 10000), so the resulting gain is negative and the
        feature is never selected.
        """
        sortedValues = sorted(example[axis] for example in dataSet)
        total = float(len(dataSet))
        bestEnt = 10000
        bestSplit = 0
        for low, high in zip(sortedValues, sortedValues[1:]):
            split = (low + high) / 2.0
            below = splitContinuousDataSet(dataSet, axis, split, 1)  # value <= split
            above = splitContinuousDataSet(dataSet, axis, split, 0)  # value > split
            ent = (len(below) / total * calcShannonEnt(below)
                   + len(above) / total * calcShannonEnt(above))
            if ent < bestEnt:
                bestEnt = ent
                bestSplit = split
        return bestSplit, bestEnt


    def chooseBestFeatureToSplit(dataSet, labels):
        """Return the index of the feature with the largest information gain.

        Args:
            dataSet: list of rows (lists); the class label is the last element.
            labels: feature names for the non-class columns of dataSet.

        Returns:
            The column index of the best feature, or -1 if no feature has a
            positive information gain.

        Side effects:
            * Stores the best split point of every continuous feature examined in
              the module-level bestSplitDict, keyed by its label.
            * If the chosen feature is continuous, overwrites that column of every
              row in dataSet in place with 1 (value <= split point) or 0
              (value > split point), so it can be split like a discrete feature.
        """
        baseEntropy = calcShannonEnt(dataSet)
        total = float(len(dataSet))
        bestInfoGain = 0.0
        bestFeature = -1

        for i in range(len(labels)):
            featureList = [example[i] for example in dataSet]
            if _isContinuous(featureList[0]):
                # Continuous feature: best binary split on a midpoint threshold.
                bestSplit, featureEnt = _bestContinuousSplit(dataSet, i)
                bestSplitDict[labels[i]] = bestSplit
            else:
                # Discrete feature: weighted entropy over each distinct value.
                featureEnt = 0.0
                for value in set(featureList):
                    subDataSet = splitDataSet(dataSet, i, value)
                    featureEnt += len(subDataSet) / total * calcShannonEnt(subDataSet)
            infoGain = baseEntropy - featureEnt
            if infoGain > bestInfoGain:
                bestFeature = i
                bestInfoGain = infoGain

        # Binarize the winning continuous feature exactly once, after all features
        # have been compared. (Doing this inside the loop re-binarized an already
        # 0/1 column on every later iteration, flipping its values.) Skip when no
        # feature was chosen: index -1 would address the class column.
        if bestFeature != -1 and _isContinuous(dataSet[0][bestFeature]):
            bestSplitValue = bestSplitDict[labels[bestFeature]]
            for example in dataSet:
                example[bestFeature] = 1 if example[bestFeature] <= bestSplitValue else 0
        return bestFeature
    
    # When no further split is possible, choose the class by majority vote.
    def majorityCnt(classList):
        """Return the most frequent element of classList.

        Ties are won by whichever value first appears in classList. An empty
        classList raises ValueError (from max()).
        """
        tally = {}
        for label in classList:
            tally[label] = tally.get(label, 0) + 1
        return max(tally, key=tally.get)
    
    # Build the decision tree.
    def createTree(dataSet,labels,data_full,labels_full):
        """Recursively build a decision tree as nested dicts.

        Args:
            dataSet: rows of feature values with the class label last. The rows
                may be modified in place: chooseBestFeatureToSplit binarizes a
                chosen continuous column.
            labels: feature names for the non-class columns of dataSet. The
                chosen feature's name is deleted from this list.
            data_full: the complete training set; used to enumerate every value
                a discrete feature can take.
            labels_full: feature names for the non-class columns of data_full.

        Returns:
            A class label (leaf), or ``{feature_label: {value: subtree}}``.
            Continuous features are labelled ``"<name><=<split>"`` with branch
            keys 1 (value <= split) and 0 (value > split).
        """
        classList = [example[-1] for example in dataSet]
        # All examples share one class: leaf.
        if classList.count(classList[0]) == len(classList):
            return classList[0]
        # No features left to split on: majority vote.
        if len(dataSet[0]) == 1:
            return majorityCnt(classList)
        bestFeature = chooseBestFeatureToSplit(dataSet, labels)
        # No feature has positive information gain (e.g. identical features with
        # different classes); index -1 would address the class column, so vote.
        if bestFeature == -1:
            return majorityCnt(classList)

        featureName = labels[bestFeature]
        uniqueVals = set(example[bestFeature] for example in dataSet)
        # Continuous features were already binarized to 0/1 ints above, so only
        # genuinely discrete (string) features remain 'str' here.
        isDiscrete = type(dataSet[0][bestFeature]).__name__ == 'str'
        if isDiscrete:
            bestFeatureLabel = featureName
            fullIndex = labels_full.index(featureName)
            # Values of this feature that occur in the full training set but not
            # in the current subset; they get a majority-class leaf below.
            unseenVals = set(example[fullIndex] for example in data_full) - uniqueVals
        else:
            bestFeatureLabel = featureName + '<=' + str(bestSplitDict[featureName])
        myTree = {bestFeatureLabel: {}}

        del labels[bestFeature]
        for value in uniqueVals:
            subLabels = labels[:]
            myTree[bestFeatureLabel][value] = createTree(
                splitDataSet(dataSet, bestFeature, value), subLabels, data_full, labels_full)

        # Previously compared type(...) to the string 'str' (always False), so
        # unseen discrete values never received a leaf.
        if isDiscrete:
            majority = majorityCnt(classList)
            for val in unseenVals:
                myTree[bestFeatureLabel][val] = majority
        return myTree
    
    # main function: build the decision tree for the watermelon data set and plot it.
    # The file is read as tab-separated; its first column is skipped (presumably a
    # row id — confirm) and its last column is used as the class label.
    df = pd.read_csv('WaterMelon_4_3.txt', sep='\t')
    data_full = df.values[:, 1:].tolist()
    # Copy each row, not just the outer list: chooseBestFeatureToSplit binarizes
    # continuous columns in place, which would otherwise rewrite data_full's rows.
    dataSet = [row[:] for row in data_full]
    labels_full = df.columns[1:-1].tolist()
    labels = labels_full[:]
    myTree = createTree(dataSet, labels, data_full, labels_full)

    createPlot(myTree)
    

    参考:

    ID3决策树

    绘图子程序

    python绘制决策树

    效果

    坚持
  • 相关阅读:
    Flutter DraggableScrollableSheet 可滚动对象的容器
    Flutter 避免阻塞ui线程
    Android Kotlin 数据驱动模板
    ng mock服务器数据
    rxjs 常用的subject
    Flutter 在同一页面显示List和Grid
    dart2native 使用Dart 在macOS,Windows或Linux上创建命令行工具
    Flutter 创建透明的路由页面
    ng 发布组件库
    js实现单张或多张图片持续无缝滚动
  • 原文地址:https://www.cnblogs.com/liudianfengmang/p/12900717.html
Copyright © 2011-2022 走看看