zoukankan      html  css  js  c++  java
  • 关于回归树的创建和剪枝

      之前对于树剪枝一直感到很神奇;最近参考介绍手工写了一下剪枝代码,才算理解到底什么是剪枝。

      首先要明白回归树作为预测的模式(剪枝是针对回归树而言),其实是叶子节点进行预测;所以在使用回归树进行预测的时候,本质都是在通过每层(每个层代表一个属性)的值的大于和小于来作为分值,进行二叉树的遍历。最后预测值其实叶子节点中左值或者右值;注意这里的叶子结点也是一个结构体,对于非叶子节点而言,他的左右值是一棵树,但是对于叶子结点而言,左右值则是一个单一的数值。

      那么剪枝的原始就是找到叶子节点,如上图所示的特征C和特征E,然后取左右值的均值,合并(merge)为一个节点。比如低于特征C,就是取值5.5,作为B树的左节点,这样特征C这个节点就被减掉了。

      但是在剪枝的时候注意一定要和原始场景进行比较,未剪枝前的偏差和剪枝后偏差做一个比较,看看到底哪个更优秀;如果剪枝后MSE值反而更加大了,就不要价值了。这里偏差的计算值是sum(power(yHat- y, 2))来进行比较即可。

      下面的就是剪枝的python实现:

     1 # 所谓剪枝即使遍历到叶子结点,然后看一下作为预测值的叶子结点,合并左右节点(即取左右子树平均数)为一个点
     2 # 但是需要比较一下合并之后的偏差和合并前的偏差,如果合并之后的方差变小了,则剪枝(取合并值),反之则保持原状
     3 def prune(tree, testData):
     4     m, n = shape(testData)
     5     # 如果测试在分类(分割)过程,某一类数据为0
     6     if(m == 0): return getMean(tree)
     7     # 下面一大段其实都是在做这一件事情:深入都叶子结点
     8     # 1. 只要左右子树中有一颗不是叶子结点,那么就以当前节点的spIndex以及spValue为分割(分类)点,对testData进行二元分类
     9     # 获得的是二元分类的数据集left set和right set
    10     if(isTree(tree["left"]) or isTree(tree["right"])):
    11         lset, rset = binSplitDataset(testData, tree["spIndex"], tree["spValue"])
    12     # 2. 继续处理不是叶子结点左右子树,对其进行递归prune(本质就是要深入到叶子结点为止)
    13     if(isTree(tree["left"])): 
    14         tree["left"] = prune(tree["left"], lset)
    15     if(isTree(tree["right"])): 
    16         tree["right"] = prune(tree["right"], rset)
    17     
    18     # 左右子树都是叶子节点了
    19     if(not isTree(tree["left"]) and not isTree(tree["right"])):
    20         # 那么就以当前叶子节点的spIndex以及spValue为分割(分类)点,对testData进行二元分类
    21         lset, rset = binSplitDataset(testData, tree["spIndex"], tree["spValue"])
    22         # 计算测试数据集和预测值(叶子结点)之间的方差,剪枝前的偏差
    23         errorNotMerge = sum(power(lset[:, -1] - tree["left"], 2)) + sum(power(rset[:, -1] - tree["right"],2))
    24         treeMean = (tree["left"] + tree["right"]) / 2.0
    25         # 测试数据全集和树均值(预测值)之间的方差,剪枝后偏差
    26         errorMerge = sum(power(testData[:, -1] - treeMean, 2))
    27         # 看看谁的方差小,如果测试数据全集和树均值的方差小,返回的是树均值(叶子结点)的均值
    28         if(errorMerge < errorNotMerge):
    29             #print("errorMerge < errorNotMerge, treeMean is: ")
    30             #print(treeMean)
    31             return treeMean
    32         # 如果叶子节点(预测值)的和真实值之间的方差比较小,则返回的树,不需要剪枝
    33         else:
    34             #print("errorMerge > errorNotMerge, [tree] is: ")
    35             #print(tree)
    36             return tree
    37     # 说明叶子结点剪枝效果不明显,不需要剪枝
    38     else:
    39         return tree
    40             

      那么再汇过来,如何构建一个回归树呢?

      构建回归树有几个条件,首先要有样本数据,叶子节点的计算方式(regLeaf),以及计算一个数据集的偏差的公式(regErr);

    1 from numpy import mean
    2 
    3 # 数据集中y值的均值
    4 def regLeaf(dataset):
    5     return mean(dataset[:, -1])
    6 
    7 # 数据集中y值的方差和
    8 def regErr(dataset):
    9     return var(dataset[:, -1]) * shape(dataset)[0]

      有了这三者之后,就可以进行构建树了。构建树的时候,首先将会选择一个区分度最好的特征以及特征值,做样本分割,然后基于分割后的样本分别构建左子树和右子树,这是一个递归的过程,发生变化的样本,以及基于变化的样本产生的新的分割特征以及特征值,这个递归过程一直到样本数据不再可分为止,此时获得就是一个value,这个就是叶子结点的left/right值(非叶子节点left/right仍然是一棵树)。

     1 # 获取最好的分割信息,这里包括分割的特征以及特征值,然后对数据进行分割,在以分割后数据为基础继续进行继续创建树,一直到数据无法再分割
     2 # (feature)为none为止。
     3 def createTree(dataset, leafType=regLeaf, errorType=regErr, ops=(1, 4)):
     4     feature, value = chooseBsetSplit(dataset, leafType, errorType, ops)
     5     # left/right值直接就是数字(不再是树了)
     6     if(feature == None):
     7         return value
     8     retTree = {}
     9     retTree["spIndex"] = feature
    10     retTree["spValue"] = value
    11     # chooseBsetSplit其实应该一并把mat0和mat1返回,这样这里就不需要再计算了。
    12     # 但是后来看了一下代码,返现该函数里面有的返回分支里面是没有mat0和mat1,所以这里再计算一下也是说的通的。
    13     lset, rset = bindSplitDataset(dataset, feature, value)
    14     retTree["left"] = createTree(lset, leafType, errorType, ops)
    15     retTree["right"] = createTree(rset, leafType, errorType, ops)
    16     
    17     return retTree

      下面的代码就是获取最佳区分特征和特征值的实现

     1 # 寻找最好的区分特征;为了能够找到需要遍历所有的特征,以及所有的特征值,然后以该特征值做分割,获取两个矩阵
     2 # 计算两个矩阵的方差,不断选出方差小的作为bestIndex以及bestValue;最后将bestIndex对应的方差和原始矩阵
     3 # 方差进行比较,如果发现最佳区分特征对应的两分割矩阵方差明显小,并且两个矩阵的样本数量都不是十分小;
     4 # 则说明该特征是OK的
     5 
     6 # 返回的feature信息可能是None,代表该节点就是叶子结点中left/right的值,该函数
     7 def chooseBsetSplit(dataset, leafType=regLeaf, errorType=regErr, ops=(1, 4)):
     8     # 可容忍的偏差,在程序开始的时候,通过errorType来计算一下dataset的y值的方差和;然后用dataset的方差
     9     # 和最好区分度的方差和做减法,如果发现差值比这个tolS要小,那么说明这次指定特征是失败的;理想的差值是要大于tols
    10     # 方差一定要比原始数据小到一定程度,这次属性指定才有意义。
    11     tolS = ops[0]
    12     tolN = ops[1] # 特征划分的样本的阈值,如果一分为二后,任何一个分类样本数少于这个阈值,这次划分就取消
    13     # 为什么==1就要退出?
    14     if(len(set(dataset[:, -1].T.tolist()[0])) == 1):
    15         #print("len(set(dataset[:, -1].T.tolist()[0])) == 1, return None feature")
    16         return None, leafType(dataset)
    17     m, n = shape(dataset)
    18     # 注意这里errorType其实就是参数,这里参数就是一个函数,默认是regErr
    19     S =errorType(dataset)
    20     # 初始化best*
    21     bestS = inf
    22     bestIndex = 0
    23     bestValue = 0
    24     iterate_num = n-1
    25     #print("iterate_num: " + str(iterate_num))
    26     # 遍历所有的特征(最后一列是结果,跳过)
    27     for featureIndex in range(iterate_num):
    28         #print("++++++++++++++++++++++ %d turns +++++++++++++++++++++++" % (featureIndex))
    29         # 遍历该特征的所有特征值
    30         for splitValue in set(dataset[:, featureIndex].A.flatten().tolist()):
    31             # 在所有训练样本上面(dataset)对于该特征,大于该特征值,小于特征值分别做数据分割,获得两个矩阵
    32             mat0, mat1 = bindSplitDataset(dataset, featureIndex, splitValue)
    33             # 如果分割的特征矩阵任意一个的样本数<tolN,那么将会跳过该特征的处理,经过分割一定要达到一定的样本数才有意义
    34             # 任意一个矩阵的样本数少说明该特征的区分度不高
    35             if(shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
    36                 #print("shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN; splitValue: %f, shape(mat0)[0]: %d, (shape(mat1)[0]: %d, tolN: %d" % (splitValue, shape(mat0)[0], shape(mat1)[0], tolN))
    37                 continue
    38             #print("*************** one ok **********************")
    39             # 和leafType一样,都是参数类型为函数,计算方差和
    40             newS = errorType(mat0) + errorType(mat1)
    41             # 如果方差小于bestS,则用当前的方差以及特征信息做替换;到此可以看到目标就是找到区分度高并且方差小的特征,作为最好
    42             # 区分特征
    43             if(newS < bestS):
    44                 bestIndex = featureIndex
    45                 bestS = newS
    46                 bestValue = splitValue
    47     # 如果S值和bestS值之差小于tolS;参见tolS的注释。
    48     if(S -bestS) < tolS:
    49         #print("(S -bestS) < tolS, return feature NULL, S: %s, bestS: %s, tolS: %s" % ( str(S), str(bestS), str(tolS)))
    50         return None, leafType(dataset)
    51     mat0, mat1 = bindSplitDataset(dataset, bestIndex, bestValue)
    52     # 这里的判断有意义吗?在循环体中其实已经做了这个判断,如果不满足也不会成为bestIndex,bestvalue;
    53     if(shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
    54         print("shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN")
    55         return None, leafType(dataset)
    56     
    57     return bestIndex, bestValue

     

  • 相关阅读:
    StringUtils工具类的使用
    struts2 文件上传和下载,以及部分源代码解析
    ios开发之猜数字游戏
    从epoll构建muduo-12 多线程入场
    POJ3009 Curling 2.0(DFS)
    IOS-4-面试题1:黑马程序猿IOS面试题大全
    Android-Universal-Image-Loader载入图片
    《UNIX环境高级编程》读书笔记 —— 文件 I/O
    畅通project再续 HDU杭电1875 【Kruscal算法 || Prim】
    轻松学习之Linux教程四 神器vi程序编辑器攻略
  • 原文地址:https://www.cnblogs.com/xiashiwendao/p/10507098.html
Copyright © 2011-2022 走看看