zoukankan      html  css  js  c++  java
  • cart树回归及其剪枝的python实现

    前言

           前文讨论的回归算法都是全局且针对线性问题的回归,即使是其中的局部加权线性回归法,也有其弊端(具体请参考前文)

           采用全局模型会导致模型非常的臃肿,因为需要计算所有的样本点,而且现实生活中很多样本都有大量的特征信息。

           另一方面,实际生活中更多的问题都是非线性问题。

           针对这些问题,有了树回归系列算法。

    回归树

           在先前决策树的学习中,构建树是采用的 ID3 算法。在回归领域,该算法就有个问题,就是派生子树是按照所有可能值来进行派生。

           因此 ID3 算法无法处理连续性数据。

           故可使用二元切分法,以某个特定值为界进行切分。在这种切分法下,子树个数小于等于2。

           除此之外,再修改择优原则香农熵 (因为数据变为连续型的了),便可将树构建成一棵可用于回归的树,这样一棵树便叫做回归树。

           构建回归树的伪代码:

    1 找到最佳的待切分特征:
    2     如果该节点不能再分,将此节点存为叶节点。
    3     执行二元切分
    4     左右子树分别递归调用此函数

           二元切分的伪代码:

    1 对每个特征:
    2     对每个特征值:
    3         将数据集切成两份
    4         计算切分误差
    5         如果当前误差小于最小误差,则更新最佳切分以及最小误差。

           特别说明,终止划分 (并直接建立叶节点)有三种情况:
                  1. 特征值划分完毕
                  2. 划分子集太小
                  3. 划分后误差改进不大
           这几个操作被称做 "预剪枝"。
      下面给出一个完整的回归树的小程序:

    复制代码
      1 #!/usr/bin/env python
      2 # -*- coding:UTF-8 -*-
      3 
      4 '''
      5 Created on 20**-**-**
      6 
      7 @author: fangmeng
      8 '''
      9 
     10 from numpy import *
     11 
     12 def loadDataSet(fileName):
     13     '载入测试数据'
     14     
     15     dataMat = []
     16     fr = open(fileName)
     17     for line in fr.readlines():
     18         curLine = line.strip().split('	')
     19         # 所有元素转换为浮点类型(函数编程)
     20         fltLine = map(float,curLine)
     21         dataMat.append(fltLine)
     22     return dataMat
     23 
     24 #============================
     25 # 输入:
     26 #        dataSet: 待切分数据集
     27 #        feature: 切分特征序号
     28 #        value:    切分值
     29 # 输出:
     30 #        mat0,mat1: 切分结果
     31 #============================
     32 def binSplitDataSet(dataSet, feature, value):
     33     '切分数据集'
     34     
     35     mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:][0]
     36     mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:][0]
     37     return mat0,mat1
     38 
     39 #========================================
     40 # 输入:
     41 #        dataSet: 数据集
     42 # 输出:
     43 #        mean(dataSet[:,-1]): 均值(也就是叶节点的内容)
     44 #========================================
     45 def regLeaf(dataSet):
     46     '生成叶节点'
     47     
     48     return mean(dataSet[:,-1])
     49 
     50 #========================================
     51 # 输入:
     52 #        dataSet: 数据集
     53 # 输出:
     54 #        var(dataSet[:,-1]) * shape(dataSet)[0]: 平方误差
     55 #========================================
     56 def regErr(dataSet):
     57     '计算平方误差'
     58     
     59     return var(dataSet[:,-1]) * shape(dataSet)[0]
     60 
     61 #========================================
     62 # 输入:
     63 #        dataSet: 数据集
     64 #        leafType: 叶子节点生成器
     65 #        errType: 误差统计器
     66 #        ops: 相关参数
     67 # 输出:
     68 #        bestIndex: 最佳划分特征 
     69 #        bestValue: 最佳划分特征值
     70 #========================================
     71 def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
     72     '选择最优划分'
     73     
     74     # 获得相关参数中的最大样本数和最小误差效果提升值
     75     tolS = ops[0]; 
     76     tolN = ops[1]
     77     
     78     # 如果所有样本点的值一致,那么直接建立叶子节点。
     79     if len(set(dataSet[:,-1].T.tolist()[0])) == 1: 
     80         return None, leafType(dataSet)
     81     
     82     m,n = shape(dataSet)
     83     # 当前误差
     84     S = errType(dataSet)
     85     # 最小误差
     86     bestS = inf; 
     87     # 最小误差对应的划分方式
     88     bestIndex = 0; 
     89     bestValue = 0
     90     
     91     # 对于所有特征
     92     for featIndex in range(n-1):
     93         # 对于某个特征的所有特征值
     94         for splitVal in set(dataSet[:,featIndex]):
     95             # 划分
     96             mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
     97             # 如果划分后某个子集中的个数不达标
     98             if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue
     99             # 当前划分方式的误差
    100             newS = errType(mat0) + errType(mat1)
    101             # 如果这种划分方式的误差小于最小误差
    102             if newS < bestS: 
    103                 bestIndex = featIndex
    104                 bestValue = splitVal
    105                 bestS = newS
    106     
    107     # 如果当前划分方式还不如不划分时候的误差效果
    108     if (S - bestS) < tolS: 
    109         return None, leafType(dataSet)
    110     # 按照最优划分方式进行划分
    111     mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    112     # 如果划分后某个子集中的个数不达标
    113     if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
    114         return None, leafType(dataSet)
    115     
    116     return bestIndex,bestValue
    117 
    118 #========================================
    119 # 输入:
    120 #        dataSet: 数据集
    121 #        leafType: 叶子节点生成器
    122 #        errType: 误差统计器
    123 #        ops: 相关参数
    124 # 输出:
    125 #        retTree: 回归树
    126 #========================================
    127 def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    128     '构建回归树'
    129     
    130     # 选择最佳划分方式
    131     feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    132     # feat为None的时候无需划分返回叶子节点
    133     if feat == None: return val #if the splitting hit a stop condition return val
    134     
    135     # 递归调用构建函数并更新树
    136     retTree = {}
    137     retTree['spInd'] = feat
    138     retTree['spVal'] = val
    139     lSet, rSet = binSplitDataSet(dataSet, feat, val)
    140     retTree['left'] = createTree(lSet, leafType, errType, ops)
    141     retTree['right'] = createTree(rSet, leafType, errType, ops)
    142     
    143     return retTree  
    144 
    145 def test():
    146     '展示结果'
    147     
    148     # 载入数据
    149     myDat = loadDataSet('/home/fangmeng/ex0.txt')
    150     # 构建回归树
    151     myDat = mat(myDat)
    152     
    153     print createTree(myDat)
    154     
    155     
    156 if __name__ == '__main__':
    157     test()
    复制代码

           测试结果:

    回归树的优化工作 - 剪枝

           在上面的代码中,终止递归的条件中已经加入了重重的 "剪枝" 工作。

           这些在建树的时候的剪枝操作通常被成为预剪枝。这是很有很有必要的,经过预剪枝的树几乎就是没有预剪枝树的大小的百分之一甚至更小,而性能相差无几。

           而在树建立完毕之后,基于训练集和测试集能做更多更高效的剪枝工作,这些工作叫做 "后剪枝"。

           可见,剪枝是一项较大的工作量,是对树非常关键的优化过程。

           后剪枝过程的伪代码如下:

    1 基于已有的树切分测试数据:
    2     如果存在任一子集是一棵树,则在该子集上递归该过程。
    3     计算将当前两个叶节点合并后的误差
    4     计算不合并的误差
    5     如果合并会降低误差,则将叶节点合并。

           具体实现函数如下:

    复制代码
     1 #===================================
     2 # 输入:
     3 #        obj: 判断对象
     4 # 输出:
     5 #        (type(obj).__name__=='dict'): 判断结果
     6 #===================================
     7 def isTree(obj):
     8     '判断对象是否为树类型'
     9     
    10     return (type(obj).__name__=='dict')
    11 
    12 #===================================
    13 # 输入:
    14 #        tree: 处理对象
    15 # 输出:
    16 #        (tree['left']+tree['right'])/2.0: 坍塌后的替代值
    17 #===================================
    18 def getMean(tree):
    19     '坍塌处理'
    20     
    21     if isTree(tree['right']): tree['right'] = getMean(tree['right'])
    22     if isTree(tree['left']): tree['left'] = getMean(tree['left'])
    23     
    24     return (tree['left']+tree['right'])/2.0
    25   
    26 #===================================
    27 # 输入:
    28 #        tree: 处理对象
    29 #        testData: 测试数据集
    30 # 输出:
    31 #        tree: 剪枝后的树
    32 #===================================  
    33 def prune(tree, testData):
    34     '后剪枝'
    35     
    36     # 无测试数据则坍塌此树
    37     if shape(testData)[0] == 0: 
    38         return getMean(tree)
    39     
    40     # 若左/右子集为树类型
    41     if (isTree(tree['right']) or isTree(tree['left'])):
    42         # 划分测试集
    43         lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    44     # 在新树新测试集上递归进行剪枝
    45     if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
    46     if isTree(tree['right']): tree['right'] =  prune(tree['right'], rSet)
    47     
    48     # 如果两个子集都是叶子的话,则在进行误差评估后决定是否进行合并。
    49     if not isTree(tree['left']) and not isTree(tree['right']):
    50         lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    51         errorNoMerge = sum(power(lSet[:,-1] - tree['left'],2)) +sum(power(rSet[:,-1] - tree['right'],2))
    52         treeMean = (tree['left']+tree['right'])/2.0
    53         errorMerge = sum(power(testData[:,-1] - treeMean,2))
    54         if errorMerge < errorNoMerge: 
    55             return treeMean
    56         else: return tree
    57     else: return tree
    复制代码

    模型树

           这也是一种很棒的树回归算法。

           该算法将所有的叶子节点不是表述成一个值,而是对叶子部分节点建立线性模型。比如可以是最小二乘法的基本线性回归模型。

           这样在叶子节点里存放的就是一组线性回归系数了。非叶子节点部分构造就和回归树一样。

           这个是上面建立回归树算法的函数头:

           createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):

           对于模型树,只需要修改修改 leafType(叶节点构造器) 和 errType(误差分析器) 的实现即可,分别对应如下modelLeaf 函数和 modelErr 函数:

    复制代码
     1 #=========================
     2 # 输入:
     3 #        dataSet: 测试集
     4 # 输出:
     5 #        ws,X,Y: 回归模型
     6 #=========================
     7 def linearSolve(dataSet):
     8     '辅助函数,用于构建线性回归模型。'
     9     
    10     m,n = shape(dataSet)
    11     X = mat(ones((m,n))); 
    12     Y = mat(ones((m,1)))
    13     X[:,1:n] = dataSet[:,0:n-1]; 
    14     Y = dataSet[:,-1]
    15     xTx = X.T*X
    16     if linalg.det(xTx) == 0.0:
    17         raise NameError('系数矩阵不可逆')
    18     ws = xTx.I * (X.T * Y)
    19     return ws,X,Y
    20 
    21 #=======================
    22 # 输入:
    23 #       dataSet: 数据集
    24 # 输出:
    25 #        ws: 回归系数
    26 #=======================
    27 def modelLeaf(dataSet):
    28     '叶节点构造器'
    29     
    30     ws,X,Y = linearSolve(dataSet)
    31     return ws
    32 
    33 #=======================================
    34 # 输入:
    35 #       dataSet: 数据集
    36 # 输出:
    37 #        sum(power(Y - yHat,2)): 平方误差
    38 #=======================================
    39 def modelErr(dataSet):
    40     '误差分析器'
    41     
    42     ws,X,Y = linearSolve(dataSet)
    43     yHat = X * ws
    44     return sum(power(Y - yHat,2))
    复制代码

    回归树 / 模型树的使用

           前面的工作主要介绍了两种树 - 回归树,模型树的构建,下面进一步学习如何利用这些树来进行预测。

           当然,本质也就是递归遍历树。

           下为遍历代码,通过修改参数设置要使用并传递进来的是回归树还是模型树:

    复制代码
     1 #==============================
     2 # 输入:
     3 #       model: 叶子
     4 #       inDat: 测试数据
     5 # 输出:
     6 #        float(model): 叶子值
     7 #==============================
     8 def regTreeEval(model, inDat):
     9     '回归树预测'
    10     
    11     return float(model)
    12 
    13 #==============================
    14 # 输入:
    15 #       model: 叶子
    16 #       inDat: 测试数据
    17 # 输出:
    18 #        float(X*model): 叶子值
    19 #==============================
    20 def modelTreeEval(model, inDat):
    21     '模型树预测'
    22     n = shape(inDat)[1]
    23     X = mat(ones((1,n+1)))
    24     X[:,1:n+1]=inDat
    25     return float(X*model)
    26 
    27 #==============================
    28 # 输入:
    29 #        tree: 待遍历树
    30 #        inDat: 测试数据
    31 #        modelEval: 叶子值获取器
    32 # 输出:
    33 #        分类结果
    34 #==============================
    35 def treeForeCast(tree, inData, modelEval=regTreeEval):
    36     '使用回归/模型树进行预测 (modelEval参数指定)'
    37     
    38     # 如果非树类型,返回值。
    39     if not isTree(tree): return modelEval(tree, inData)
    40     
    41     # 左遍历
    42     if inData[tree['spInd']] > tree['spVal']:
    43         if isTree(tree['left']): return treeForeCast(tree['left'], inData, modelEval)
    44         else: return modelEval(tree['left'], inData)
    45         
    46     # 右遍历
    47     else:
    48         if isTree(tree['right']): return treeForeCast(tree['right'], inData, modelEval)
    49         else: return modelEval(tree['right'], inData)
    复制代码

           使用方法非常简单,将树和要分类的样本传递进去就可以了。如果是模型树就将分类函数 treeForeCast 的第三个参数改为modelTreeEval即可。

           这里就不再演示实验具体过程了。

    小结

           1. 选择哪个回归方法,得看哪个方法的相关系数高。(可使用 corrcoef 函数计算)

           2. 树的回归和分类算法其实本质上都属于贪心算法,不断去寻找局部最优解。

           3. 关于回归的讨论就先告一段落,接下来将进入到无监督学习部分。

  • 相关阅读:
    System.arraycopy()的用法?
    Java当中“+=”和“=+”的区别
    jsp FN 标签库的使用方法
    手作编辑画面处理
    mpfu 位编辑处理?
    5/14 自动跟新 位数编集 百分号添加 手作部品。
    jsp 4-14 知识总结
    jstl split 分割字符串?
    aws vpc 知识总结(助理级)
    典型的软件自动化测试框架
  • 原文地址:https://www.cnblogs.com/fujian-code/p/7637722.html
Copyright © 2011-2022 走看看