zoukankan      html  css  js  c++  java
  • 我的机器学习之旅(七):决策树(下)

    ID3、C4.5生成决策树的算法,当训练数据量很大的时候,建立的决策树模型往往非常复杂,树的深度很大。此时虽然对训练数据拟合得很好,但是其泛化能力即预测新数据的能力并不一定很好,也就是出现了过拟合现象。这个时候我们就需要对决策树进行剪枝处理以简化模型。另外,CART算法也可用于建立回归树。

    CART算法 

    CART,即分类与回归树(classification and regression tree),也是一种应用很广泛的决策树学习方法。但是CART算法比较强大,既可用作分类树,也可以用作回归树。作为分类树时,其本质与ID3、C4.5并有多大区别,只是选择特征的依据不同而已。当CART用作回归树时,以最小平方误差作为划分样本的依据。

    1.分类树

    (1)基尼指数

    1、是一种不等性度量;
    2、通常用来度量收入不平衡,可以用来度量任何不均匀分布;
    3、是介于0~1之间的数,0-完全相等,1-完全不相等;
    4、总体内包含的类别越杂乱,GINI指数就越大(跟熵的概念很相似)

    分类树采用基尼指数选择最优特征。假设有K个类,样本点属于第k类的概率为pk,则概率分布的基尼指数定义为:

    (Gini(p)=sum_{k=1}^K p_k(1-p_k)=1-sum_{k=1}^K p_k^2)

    对于给定的样本集合D,其基尼指数为:

    (Gini(D)=1-sum_{k=1}^Kleft(frac{|C_k|}{|D|} ight)^2)

    这里,Ck是D中属于第k类的样本子集,K是类的个数。

    基尼指数Gini(D)表示集合D的不确定性,基尼指数Gini(D,A)表示经过A=a分割后集合D的不确定性。基尼指数越大,样本的不确定性也就越大。

    def calcGini(dataSet):
        '''
                计算基尼指数
        :param dataSet:数据集
        :return: 计算结果
        '''
        numEntries = len(dataSet)
        labelCounts = {}
        for featVec in dataSet: # 遍历每个实例,统计标签的频数
            currentLabel = featVec[-1]
            if currentLabel not in labelCounts.keys(): 
                labelCounts[currentLabel] = 0
            labelCounts[currentLabel] += 1
        Gini = 1.0
        for key in labelCounts:
            prob = float(labelCounts[key]) / numEntries
            Gini -= prob * prob # 以2为底的对数
        return Gini

    那么在给定特征A的条件下,集合D的基尼指数定义为:

    (Gini(D,A)=frac{|D_1|}{D}Gini(D_1)+frac{|D_2|}{|D|}Gini(D_2))

    基尼指数Gini(D)表示集合D的不确定性,基尼指数Gini(D,A)表示经A=a分割后集合D的不确定性。基尼指数值越大,样本集合的不确定性也就越大,这一点与熵相似。

    def calcGiniWithFeat(dataSet, feature, value):
        '''
        计算给定特征下的基尼指数
        :param dataSet:数据集
        :param feature:特征维度
        :param value:该特征变量所取的值
         :return: 计算结果
        '''
        numEntries=len(dataSet)
        featList = [example[feature] for example in dataSet] # 第i维特征列表
        uniqueVals = set(featList) # 转换成集合
        Gini_feat=0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, feature, value)
            prob = len(subDataSet ) / numEntries
            Gini_feat+= prob * calcGini(subDataSet ) # 以2为底的对数
        return Gini_feat
    
    
    def chooseBestSplit(dataSet):
        numFeatures = len(dataSet[0])-1
        bestGini = float('inf'); bestFeat = 0; bestValue = 0; newGini = 0
        for i in range(numFeatures):
            featList = [example[i] for example in dataSet]
            uniqueVals = set(featList)
            for splitVal in uniqueVals:
                newGini = calcGiniWithFeat(dataSet, i, splitVal)
                if newGini < bestGini:
                    bestFeat = i
                    bestGini = newGini
        return bestFeat
    
    def createTree_CART(dataSet,labels):
        '''
                创建决策树
        :param: dataSet:训练数据集
        :return: labels:所有的类标签
        '''
        classList = [example[-1] for example in dataSet]
        if classList.count(classList[0]) == len(classList): 
            return classList[0]             # 第一个递归结束条件:所有的类标签完全相同
        if len(dataSet[0]) == 1:        
            return majorityCnt(classList)   # 第二个递归结束条件:用完了所有特征
        bestFeat = chooseBestSplit(dataSet)   # 最优划分特征
        bestFeatLabel = labels[bestFeat]
        myTree = {bestFeatLabel:{}}         # 使用字典类型储存树的信息
        label_tmp=labels[:bestFeat]+labels[bestFeat+1:]
        featValues = [example[bestFeat] for example in dataSet]
        uniqueVals = set(featValues)
        for value in uniqueVals:
            subLabels = label_tmp[:]       # 复制所有类标签,保证每次递归调用时不改变原始列表的内容
            myTree[bestFeatLabel][value] = createTree_CART(splitDataSet(dataSet, bestFeat, value),subLabels)
        return myTree

    剪枝

    在决策树学习中将已生成的树进行简化的过程称为剪枝。决策树的剪枝往往通过极小化决策树的损失函数或代价函数来实现。实际上剪枝的过程就是一个动态规划的过程:从叶结点开始,自底向上地对内部结点计算预测误差以及剪枝后的预测误差,如果两者的预测误差是相等或者剪枝后预测误差更小,当然是剪掉的好。但是如果剪枝后的预测误差更大,那就不要剪了。剪枝后,原内部结点会变成新的叶结点,其决策类别由多数表决法决定。不断重复这个过程往上剪枝,直到预测误差最小为止。

    import operator
    def isTree(obj):
        return (type(obj).__name__=='dict')
    def testMajor(major,testData): 
        errorCount = 0.0 
        for i in range(len(testData)): 
            if major != testData[i][-1]: 
                errorCount += 1 
        return float(errorCount)
    
    
    def calcTestErr(myTree,testData,labels):
        errorCount = 0.0
        for i in range(len(testData)): 
            if classify(myTree,labels,testData[i]) != testData[i][-1]:
                errorCount += 1 
        return float(errorCount)
    
    def pruningTree(inputTree,dataSet,testData,labels):  
        import copy
        firstStr = list(inputTree.keys())[0]  
        secondDict = inputTree[firstStr]        # 获取子树
        classList = [example[-1] for example in dataSet]  
        featKey = copy.deepcopy(firstStr)  
        labelIndex = labels.index(featKey)  
        subLabels = copy.deepcopy(labels)
        label_tmp=labels[:labelIndex]+labels[labelIndex+1:] 
        for key in list(secondDict.keys()):  
            if isTree(secondDict[key]):
                # 深度优先搜索,递归剪枝
                subDataSet = splitDataSet(dataSet,labelIndex,key)
                subTestSet = splitDataSet(testData,labelIndex,key)
                if len(subDataSet) > 0 and len(subTestSet) > 0:
                    inputTree[firstStr][key] = pruningTree(secondDict[key],subDataSet,subTestSet,copy.deepcopy(label_tmp))
        if calcTestErr(inputTree,testData,subLabels) < testMajor(majorityCnt(classList),testData):
            # 剪枝后的误差反而变大,不作处理,直接返回
            return inputTree 
        else:
            # 剪枝,原父结点变成子结点,其类别由多数表决法决定
            return majorityCnt(classList)                                    

    回归树

     回归树的生成实际上也是贪心算法。与分类树不同的是回归树处理的数据连续分布的。

    算法步骤:

    from numpy import *
    def regLeaf(dataSet):           #建立叶节点函数,value为所有y的均值  
        return mean(dataSet[:,-1])  
    def regErr(dataset):  
        return var(dataset[:, -1]) * shape(dataset)[0]#y的方差×y的数量=平方误差  
    
    def binSplitDataset(dataSet, feature, value):#以这一列的每个值为界限,大于它和小于它的值,返回的是以这个特征值为界限分割的数据集  
        mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :]#返回索引,切割数据集  
        mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]  
        return mat0, mat1  
    
    
    
    def chooseBestsplit(dataset, leafType=regLeaf, errtype = regErr,ops=(1, 4)):#找到最好的分割叶子节点  
        tolS = ops[0]##允许的误差下降值  
        tolN = ops[1] #切分的最小样本数  
        #判断是否可以分开二叉树  
        # print(len(set(dataset[:, -1].T.tolist()[0])))#不是一下子分开,然后就是先分割整个数据集,然后分割左边,然后右边  
        if len(set(dataset[:, -1].T.tolist())) == 1:  # #如果剩余特征值的数量等于1,不需要再切分直接返回,(退出条件1)  
            return None, leafType(dataset)  
        m, n = shape(dataset)#行列数  
        S = errtype(dataset)#计算平方差  
        bestS = inf  
        bestIndex = 0  
        bestValue = 0  
        for featIndex in range(n - 1):#特征索引  
            for splitVal in set((dataset[:, featIndex].T.tolist())):  #每一列的每个值  
                mat0, mat1 = binSplitDataset(dataset, featIndex, splitVal)#整个数据集,第几列,那一列的每个值  
                if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue#样本数最小限制  
                # print(errtype(mat0))  
                newS = errtype(mat0) + errtype(mat1)#计算平方误差  
                if newS < bestS:  
                    bestIndex = featIndex  
                    bestValue = splitVal  
                    bestS = newS  
      
        if (S - bestS) < tolS:#如果切分后误差效果下降不大,则取消切分,直接创建叶结点  
            return None, leafType(dataset)  
        mat0, mat1 = binSplitDataset(dataset, bestIndex, bestValue)  # 按照保存的最佳分割来划分集合  
        # #判断切分后子集大小,小于最小允许样本数停止切分3  
        if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):  
            return None, leafType(dataset)  
        # 返回最佳二元切割的bestIndex和bestValue  
        return bestIndex, bestValue#返回特征编号和用于切分的特征值  
    
    def createTree_reg(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):#assume dataSet is NumPy Mat so we can array filtering  
        feat, val = chooseBestsplit(dataSet, leafType, errType, ops)    #采用最佳分割,将数据集分成两个部分  
        if feat == None: return val     #递归结束条件  
        retTree = {}                    #建立返回的字典  
        retTree['spInd'] = feat  
        retTree['spVal'] = val  
        lSet, rSet = binSplitDataset(dataSet, feat, val)    #得到左子树集合和右子树集合  
        retTree['left'] = createTree_reg(lSet, leafType, errType, ops)      #递归左子树  
        retTree['right'] = createTree_reg(rSet, leafType, errType, ops)     #递归右子树  
        return retTree  

    后剪枝

    def getMean(tree):
        if isTree(tree['right']): tree['right'] = getMean(tree['right'])
        if isTree(tree['left']): tree['left'] = getMean(tree['left'])
        return (tree['left']+tree['right'])/2.0
    
    #树的后剪枝,  
    def prune(tree, testData):#待剪枝的树和剪枝所需的测试数据  
        if shape(testData)[0] == 0:# 确认数据集非空  
            return getMean(tree)  
        #假设发生过拟合,采用测试数据对树进行剪枝  
        if (isTree(tree['right']) or isTree(tree['left'])): #左右子树非空  
            lSet, rSet = binSplitDataset(testData, tree['spInd'], tree['spVal'])#按照索引,和值分割数据集  
        if isTree(tree['left']):  
            tree['left'] = prune(tree['left'], lSet)  
        if isTree(tree['right']):  
            tree['right'] = prune(tree['right'], rSet)  
        #剪枝后判断是否还是有子树  
        if not isTree(tree['left']) and not isTree(tree['right']):#只要有一个空  
            lSet, rSet = binSplitDataset(testData, tree['spInd'], tree['spVal'])  
            #判断是否融合  
            errorNoMerge = sum(power(lSet[:, -1] - tree['left'], 2)) + sum(power(rSet[:, -1] - tree['right'], 2))#未熔合的方差  
            treeMean = (tree['left'] + tree['right']) / 2.0#平均值  
            errorMerge = sum(power(testData[:, -1] - treeMean, 2))#查看融合后的方差  
            #如果合并后误差变小,融合,将两个叶子的均值作为节点  
            if errorMerge < errorNoMerge:  
                print("merging")  
                return treeMean  
            else:  
                return tree  
        else:  
            return tree  

     

    参考:https://www.cnblogs.com/pinard/p/6050306.html

       https://www.cnblogs.com/pinard/p/6053344.html

  • 相关阅读:
    pandas 学习 第2篇:Series -(创建,属性,转换和索引)
    pandas 学习 第1篇:pandas基础
    linux中的软连接和硬链接
    分布式与集群的简单讲解
    Redis持久化
    CentOS7安装后无法使用鼠标选中,复制问题解决
    centos 7 安装 ifconfig 管理命令
    ES分布式文档数据库讲解
    Storm,Spark和Flink三种流式大数据处理框架对比
    mvn常见参数命令讲解
  • 原文地址:https://www.cnblogs.com/allenren/p/8662504.html
Copyright © 2011-2022 走看看