zoukankan      html  css  js  c++  java
  • (二)《机器学习》(周志华)第4章 决策树 笔记 理论及实现——“西瓜树”——CART决策树

    CART决策树

    (一)《机器学习》(周志华)第4章 决策树 笔记 理论及实现——“西瓜树”

    参照上一篇ID3算法实现的决策树(点击上面链接直达),进一步实现CART决策树。

    其实只需要改动很小的一部分就可以了,把原先计算信息熵和信息增益的部分换做计算基尼指数,选择最优属性的时候,选择最小的基尼指数即可。

    #导入模块
    import pandas as pd
    import numpy as np
    from collections import Counter
    
    #数据获取与处理
    def getData(filePath):
        data = pd.read_excel(filePath)
        return data
    
    def dataDeal(data):
        dataList = np.array(data).tolist()
        dataSet = [element[1:] for element in dataList]
        return dataSet
    
    #获取属性名称
    def getLabels(data):
        labels = list(data.columns)[1:-1]
        return labels
    
    #获取类别标记
    def targetClass(dataSet):
        classification = set([element[-1] for element in dataSet])
        return classification
        
    #将分支结点标记为叶结点,选择样本数最多的类作为类标记
    def majorityRule(dataSet):
        mostKind = Counter([element[-1] for element in dataSet]).most_common(1)
        majorityKind = mostKind[0][0]
        return majorityKind
    
    ##计算基尼值
    def calculateGini(dataSet):
        classColumnCnt = Counter([element[-1] for element in dataSet])
        gini = 0
        for symbol in classColumnCnt:
            p_k = classColumnCnt[symbol]/len(dataSet)
            gini = gini+p_k**2
        gini = 1-gini
        return gini
    
    #子数据集构建
    def makeAttributeData(dataSet,value,iColumn):
        attributeData = []
        for element in dataSet:
            if element[iColumn]==value:
                row = element[:iColumn]
                row.extend(element[iColumn+1:])
                attributeData.append(row)
        return attributeData
    
    #计算基尼指数
    def GiniIndex(dataSet,iColumn):
        index = 0.0
        attribute = set([element[iColumn] for element in dataSet])
        for value in attribute:
            attributeData = makeAttributeData(dataSet,value,iColumn)
            index = index+len(attributeData)/len(dataSet)*calculateGini(attributeData)
        return index
    
    #选择最优属性                
    def selectOptimalAttribute(dataSet,labels):
        bestGini = []
        for iColumn in range(0,len(labels)):#不计最后的类别列
            index = GiniIndex(dataSet,iColumn)
            bestGini.append(index)
        sequence = bestGini.index(min(bestGini))
        return sequence
        
    #建立决策树
    def createTree(dataSet,labels):
        classification = targetClass(dataSet) #获取类别种类(集合去重)
        if len(classification) == 1:
            return list(classification)[0]
        if len(labels) == 1:
            return majorityRule(dataSet)#返回样本种类较多的类别
        sequence = selectOptimalAttribute(dataSet,labels)
        optimalAttribute = labels[sequence]
        del(labels[sequence])
        myTree = {optimalAttribute:{}}
        attribute = set([element[sequence] for element in dataSet])
        for value in attribute:
            subLabels = labels[:]
            myTree[optimalAttribute][value] =  
                    createTree(makeAttributeData(dataSet,value,sequence),subLabels)
        return myTree
    
    #定义主函数
    def main():
        filePath = 'watermelonData.xls'
        data = getData(filePath)
        dataSet = dataDeal(data)
        labels = getLabels(data)
        myTree = createTree(dataSet,labels)
        return myTree
    
    #读取数据文件并转换为列表(含有汉字的,使用CSV格式读取容易出错)
    if __name__ == '__main__':
        myTree = main()
        print (myTree)

     结果竟然是一样的,深度怀疑做错了。

  • 相关阅读:
    JS判断页面是否加载完成
    简单的前端验证码
    如何让旧浏览器支持HTML5新标签
    JSON使用(4)
    JSON语法(3)
    JSON简介(2)
    JSON教程(1)
    jQuery-noConflict()
    jQuery
    jQuery
  • 原文地址:https://www.cnblogs.com/dennis-liucd/p/7944033.html
Copyright © 2011-2022 走看看