
    Machine Learning - Decision Trees - The C4.5 Decision Tree

    To address several shortcomings of the ID3 algorithm, Quinlan improved ID3 into the C4.5 algorithm in 1993. C4.5 successfully resolved many of the problems ID3 ran into and went on to become one of the top ten algorithms in machine learning.

    C4.5 does not change ID3's algorithmic logic, and its basic program structure is the same as ID3's, but it improves the splitting criterion at each node. C4.5 selects features by gain ratio (GainRatio) instead of information gain (Gain), which overcomes information gain's bias toward features with many distinct values.

    Gain ratio:

    GainRatio(S,A) = Gain(S,A) / SplitInfo(S,A)

    where Gain(S,A) is the information gain from ID3, and the split information SplitInfo(S,A) measures the breadth and uniformity with which feature A partitions the sample set S:

    SplitInfo(S,A) = -Σ(i=1..c) (|Si|/|S|) * log2(|Si|/|S|)

    where S1, ..., Sc are the sample subsets formed by the c distinct values of feature A.
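    As a quick sanity check on these formulas, here is a minimal sketch (the feature values and class labels below are invented purely for illustration) that computes Gain, SplitInfo, and GainRatio for one candidate feature:

    import math

    def entropy(labels):
        # Shannon entropy (in bits) of a list of class labels
        n = len(labels)
        return -sum((labels.count(c) / n) * math.log2(labels.count(c) / n)
                    for c in set(labels))

    # Hypothetical feature A: each value maps to the class labels of its subset
    subsets = {"young": ["no", "no", "yes"], "old": ["yes", "yes", "yes"]}
    all_labels = [c for rows in subsets.values() for c in rows]
    n = len(all_labels)

    gain = entropy(all_labels) - sum(len(rows) / n * entropy(rows)
                                     for rows in subsets.values())
    split_info = -sum(len(rows) / n * math.log2(len(rows) / n)
                      for rows in subsets.values())
    print(gain / split_info)  # GainRatio(S,A), about 0.459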

    Code

    # C4.5 decision tree: selects the best split feature by gain ratio
    from numpy import *
    import math
    import copy
    import pickle
    
    class C45DTree(object):
        def __init__(self):  # Constructor
            self.tree = {}  # The generated tree
            self.dataSet = []  # Data set
            self.labels = []  # Feature labels
    
        # Load the data set from a tab-separated text file
        def loadDataSet(self, path, labels):
            recordlist = []
            fp = open(path, "r")  # Read the file content
            content = fp.read()
            fp.close()
            rowlist = content.splitlines()  # Split the content into rows
            recordlist = [row.split("\t") for row in rowlist if row.strip()]
            self.dataSet = recordlist
            self.labels = labels
    
        # Train: build the decision tree from the loaded data set
        def train(self):
            labels = copy.deepcopy(self.labels)
            self.tree = self.buildTree(self.dataSet, labels)

        # Main routine: recursively build the decision tree
        def buildTree(self, dataSet, labels):
            cateList = [data[-1] for data in dataSet]  # Extract the class label column
            # Stopping condition 1: all samples share one class label; return it
            if cateList.count(cateList[0]) == len(cateList):
                return cateList[0]
            # Stopping condition 2: no features left (only the label column); return the majority class
            if len(dataSet[0]) == 1:
                return self.maxCate(cateList)
            # Core of the algorithm:
            bestFeat, featValueList = self.getBestFeat(dataSet)  # Index and values of the best feature axis
            bestFeatLabel = labels[bestFeat]
            tree = {bestFeatLabel: {}}
            del (labels[bestFeat])
            # Grow the tree recursively over each value of the best feature
            for value in featValueList:
                subLabels = labels[:]  # Label list with the chosen feature removed
                # Split the data set on the best feature column and this value
                splitDataset = self.splitDataSet(dataSet, bestFeat, value)
                subTree = self.buildTree(splitDataset, subLabels)  # Build the subtree
                tree[bestFeatLabel][value] = subTree
            return tree
    
        # Return the class label that occurs most often
        def maxCate(self, catelist):
            # Map occurrence count -> label, then take the label with the max count
            items = dict([(catelist.count(i), i) for i in catelist])
            return items[max(items.keys())]
    
        # Compute the Shannon entropy of the data set's class labels
        def computeEntropy(self, dataSet):
            datalen = float(len(dataSet))
            cateList = [data[-1] for data in dataSet]  # Class labels of the data set
            # Dictionary of class label -> occurrence count
            items = dict([(i, cateList.count(i)) for i in cateList])
            infoEntropy = 0.0  # Initialize the Shannon entropy
            for key in items:  # Shannon entropy: -sum(p * log2(p))
                prob = float(items[key]) / datalen
                infoEntropy -= prob * math.log(prob, 2)
            return infoEntropy
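        # For example, two "yes" rows and two "no" rows give maximal entropy
        # (1.0 bit), while a pure subset containing a single class gives 0.0.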
    
        # Split the data set: keep rows where the feature axis equals value,
        # dropping the feature column itself from the returned rows
        # dataSet: data set   axis: feature column index   value: feature value to match
        def splitDataSet(self, dataSet, axis, value):
            rtnList = []
            for featVec in dataSet:
                if featVec[axis] == value:
                    rFeatVec = featVec[:axis]  # Elements 0..(axis-1)
                    rFeatVec.extend(featVec[axis + 1:])  # Append the elements after the feature column
                    rtnList.append(rFeatVec)
            return rtnList
    
        # Compute the split information (SplitInfo) of a feature column
        def computeSplitInfo(self, featureVList):
            numEntries = len(featureVList)
            featureValueListSetList = list(set(featureVList))
            valueCounts = [featureVList.count(featVec) for featVec in featureValueListSetList]
            # Entropy of the feature-value distribution
            pList = [float(item) / numEntries for item in valueCounts]
            lList = [item * math.log(item, 2) for item in pList]
            splitInfo = -sum(lList)
            return splitInfo, featureValueListSetList
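        # For example, a feature column ['a', 'a', 'b', 'c'] gives
        # p = [0.5, 0.25, 0.25] and SplitInfo = 1.5 bits, returned
        # together with the list of distinct feature values.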
    
        # Pick the best split feature by gain ratio
        def getBestFeat(self, dataSet):
            Num_feats = len(dataSet[0][:-1])
            totality = len(dataSet)
            BaseEntropy = self.computeEntropy(dataSet)
            ConditionEntropy = []  # Conditional entropy of each feature
            splitInfo = []  # Split information of each feature
            allFeatVList = []
            for f in range(Num_feats):
                featList = [example[f] for example in dataSet]
                [splitI, featureValueList] = self.computeSplitInfo(featList)
                allFeatVList.append(featureValueList)
                splitInfo.append(splitI)
                resultGain = 0.0
                for value in featureValueList:
                    subSet = self.splitDataSet(dataSet, f, value)
                    appearNum = float(len(subSet))
                    subEntropy = self.computeEntropy(subSet)
                    resultGain += (appearNum/totality) * subEntropy
                ConditionEntropy.append(resultGain)  # Conditional entropy for this feature
            infoGainArray = BaseEntropy * ones(Num_feats) - array(ConditionEntropy)
            # C4.5: gain ratio = information gain / split information.
            # The small epsilon guards against division by zero when a feature
            # takes only one value in this subset (SplitInfo = 0).
            infoGainRatio = infoGainArray / (array(splitInfo) + 1e-10)
            bestFeatureIndex = argsort(-infoGainRatio)[0]
            return bestFeatureIndex, allFeatVList[bestFeatureIndex]
    
        # Classify a test vector by walking the tree
        def predict(self, inputTree, featLabels, testVec):
            root = list(inputTree.keys())[0]  # Root node (a feature label)
            secondDict = inputTree[root]  # Value: a subtree or a class label
            featIndex = featLabels.index(root)  # Position of the root feature in the label list
            key = testVec[featIndex]  # The test vector's value for that feature
            valueOfFeat = secondDict[key]
            if isinstance(valueOfFeat, dict):
                classLabel = self.predict(valueOfFeat, featLabels, testVec)  # Recurse into the subtree
            else:
                classLabel = valueOfFeat
            return classLabel
    
    
        # Persist the tree to disk with pickle
        def storeTree(self, inputTree, filename):
            with open(filename, 'wb') as fw:
                pickle.dump(inputTree, fw)
    
        # Load a pickled tree from disk
        def grabTree(self, filename):
            with open(filename, 'rb') as fr:
                return pickle.load(fr)
    
    # Training
    dtree = C45DTree()
    dtree.loadDataSet("/Users/FengZhen/Desktop/accumulate/机器学习/决策树/决策树训练集.txt", ["age", "revenue", "student", "credit"])
    dtree.train()
    print(dtree.tree)
    
    # Persistence: save the tree and load it back
    dtree.storeTree(dtree.tree, "/Users/FengZhen/Desktop/accumulate/机器学习/决策树/决策树C45.tree")
    mytree = dtree.grabTree("/Users/FengZhen/Desktop/accumulate/机器学习/决策树/决策树C45.tree")
    print(mytree)
    
    # Test: classify a single sample
    labels = ["age", "revenue", "student", "credit"]
    vector = ['0','1','0','0']
    print(dtree.predict(mytree, labels, vector))
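    For reference, loadDataSet expects a tab-separated text file whose last column is the class label. A hypothetical training file (the rows below are invented for illustration; the real file at the path above is not shown in the post) would start like:

    0	0	0	0	N
    0	0	1	0	Y
    1	1	0	1	Y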
  • Original post: https://www.cnblogs.com/EnzoDin/p/12431768.html