zoukankan      html  css  js  c++  java
  • 树回归 CART算法

    线性回归创建的预测模型需要拟合所有的样本点,在数据拥有众多特征并且特征之间关系十分复杂时,构建全局模型太难,而且,生活中很多问题是非线性的,不可能使用全局线性模型来拟合任何数据。

    一种可行的方法是把数据集切分成很多分易建模的数据,然后利用线性回归技术来建模。如果首次切分后仍然难以拟合线性模型就继续切分。这种切分方式下,树结构和回归法就相当有用。

    CART算法:分类回归树,既可用于分类也可用于回归。

    第三章使用的决策树构建算法是ID3,每次选取当前最佳的特征来分割数据。属于贪心算法,不考虑能否达到全局最优。而且容易造成过拟合、不能直接处理连续型特征,只有事先将连续型特征转换成离散型,才能使用ID3算法。

    而使用二元切分法则易于对树构建过程进行调整以处理连续型特征。如果特征值大于给定值就走左子树,小于给定值就走右子树。

    CART算法的实现代码:

    from numpy import *
    def loadDataSet(filename):
        dataMat=[]
        f=open(filename)
        for line in f.readlines():
            curLine=line.strip().split('	')
            floatLine=list(map(float,curLine))
            dataMat.append(floatLine)
        return dataMat
    def binSplitDataSet(dataSet,feature,value):
        mat0=dataSet[nonzero(dataSet[:,feature]>value)[0],:]
        mat1=dataSet[nonzero(dataSet[:,feature]<=value)[0],:]
        return mat0,mat1
    def createTree(dataSet,leafType=regLeaf,errType=regErr,ops=(1,4)):
        feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
        if feat == None: return val
        retTree = {}
        retTree['spInd'] = feat
        retTree['spVal'] = val
        lSet, rSet = binSplitDataSet(dataSet, feat, val)
        retTree['left'] = createTree(lSet, leafType, errType, ops)
        retTree['right'] = createTree(rSet, leafType, errType, ops)
        return retTree
    

    chooseBestSplit()函数暂未实现。


     

     将CART算法用于回归:
    回归树假设叶子节点是常数值。用平方误差的总值(总方差)来计算连续型数值的混乱程度。总方差等于均方差乘以数据集中样本点的个数。

    chooseBestSplit():给定某个误差计算方法,该函数会找到数据集上最佳的二元切分方式。还要确定什么时候停止切分,一旦停止切分就会生成一个叶子节点。所以:用最佳方式切分数据集和生成相应的叶节点。

    伪代码:

    对每个特征:
        对每个特征值:
            将数据集切分成两份
            计算切分后的误差
            如果当前误差小于当前最小误差,将当前切分设定为最佳切分并更新最小误差
        返回最佳切分的特征和阈值
    

     切分函数的实现:

    def regLeaf(dataSet):   #负责生成叶节点,当chooseBestSplit函数确定不再对数据进行切分时,将调用regLeaf函数得到叶节点的模型
        return mean(dataSet[:,-1])  #在回归树中,此模型就是目标变量的均值
    
    def regErr(dataSet):    # 误差估计函数,计算目标变量的平方误差,需要返回总误差,即为均方误差乘以数据集中样本个数
        return var(dataSet[:, -1]) * shape(dataSet)[0]
    
    def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)): #ops为用户指定的参数,用于控制函数的停止时机
        tolS = ops[0]  # 容许的误差下降值
        tolN = ops[1]  # 切分的最少样本数
        if len(set(dataSet[:, -1].T.tolist()[0])) == 1:  # 统计不同剩余特征值得数目,如果数目为一,就不需要再切分而直接返回
            return None, leafType(dataSet)
        else:
            m, n = shape(dataSet)
            S = errType(dataSet)    #误差
            bestS = inf     #最小误差
            bestIndex = 0
            bestValue = 0
            for featIndex in range(n - 1):  # 对所有特征进行遍历,找到最佳切分方式。最佳切分就是使得切分后能达到最低误差的切分
                # for splitVal in set(dataSet[:, featIndex]):  # 遍历某个特征的所有特征值
                for splitVal in set((dataSet[:, featIndex].T.A.tolist())[0]):
                    mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)  # 按照某个特征的某个值将数据切分成两个数据子集
                    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):  # 如果某个子集行数不大于tolN,也不应该切分
                        continue
                    newS = errType(mat0) + errType(mat1)  # 新误差由切分后的两个数据子集组成的误差
                    if newS < bestS:  # 判断新切分能否降低误差
                        bestIndex = featIndex
                        bestValue = splitVal
                        bestS = newS
            if (S - bestS) < tolS:  # 如果误差降低不大则退出
                return None, leafType(dataSet)
            mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):  # 如果切分出的数据集很小则退出
                return None, leafType(dataSet)
            return bestIndex, bestValue
    

     regLeaf():负责生成叶节点,即求当前数据集目标值的平均值作为回归预测值。当chooseBestSplit()确定不再对数据进行切分时,将调用regLeaf()函数来得到叶节点的模型。回归树中,该模型是目标变量的均值。

    regErr():误差估计函数。在给定数据集上计算目标变量的平方误差。

    chooseBestSplit():构建回归树的核心函数。目的是找到数据的最佳二元切分方式。如果找不到好的二元切分,就返回None并同时调用regLeaf()方法来产生叶节点。

    运行代码:

    if __name__=='__main__':
        myMat=loadDataSet('ex00.txt')
        myMat=mat(myMat)
        result=createTree(myMat)
        print(result)
    

     输出为:

    {'spInd': 0, 'spVal': 0.48813, 'right': -0.04465028571428572, 'left': 1.0180967672413792}
    

    只有两个叶节点,对照下面的散点图可以看出,在数据0.48813左侧的数据,回归预测值为-0.04465,右侧预测值为1.018。

    数据集散点图:

    因为数据集简单,所以得到的回归树也简单。

    更换数据集测试:

    if __name__=='__main__':
        myMat2=loadDataSet('ex2.txt')
        myMat2=mat(myMat2)
        myTree = createTree(myMat2, ops=(0, 1))
        print(myTree)
    

     输出:

    {'spInd': 0, 'spVal': 0.499171, 'right': {'spInd': 0, 'spVal': 0.457563, 'right': {'spInd': 0, 'spVal': 0.455761, 'right': {'spInd': 0, 'spVal': 0.126833, 'right': {'spInd': 0, 'spVal': 0.124723, 'right': {'spInd': 0, 'spVal': 0.085111, 'right': {'spInd': 0, 'spVal': 0.084661, 'right': {'spInd': 0, 'spVal': 0.080061, 'right': {'spInd': 0, 'spVal': 0.068373, 'right': {'spInd': 0, 'spVal': 0.061219, 'right': {'spInd': 0, 'spVal': 0.044737, 'right': {'spInd': 0, 'spVal': 0.028546, 'right': {'spInd': 0, 'spVal': 0.000256, 'right': 9.668106, 'left': -8.377094}, 'left': {'spInd': 0, 'spVal': 0.039914, 'right': 11.220099, 'left': 3.855393}}, 'left': {'spInd': 0, 'spVal': 0.053764, 'right': -13.731698, 'left': {'spInd': 0, 'spVal': 0.055862, 'right': -3.131497, 'left': 6.695567}}}, 'left': -15.160836}, 'left': {'spInd': 0, 'spVal': 0.079632, 'right': 29.420068, 'left': 2.229873}}, 'left': -24.132226}, 'left': 37.820659}, 'left': {'spInd': 0, 'spVal': 0.108801, 'right': {'spInd': 0, 'spVal': 0.10796, 'right': {'spInd': 0, 'spVal': 0.085873, 'right': -10.137104, 'left': -1.293195}, 'left': -16.106164}, 'left': {'spInd': 0, 'spVal': 0.11515, 'right': 13.795828, 'left': -1.402796}}}, 'left': 22.891675}, 'left': {'spInd': 0, 'spVal': 0.130626, 'right': -39.524461, 'left': {'spInd': 0, 'spVal': 0.382037, 'right': {'spInd': 0, 'spVal': 0.335182, 'right': {'spInd': 0, 'spVal': 0.324274, 'right': {'spInd': 0, 'spVal': 0.309133, 'right': {'spInd': 0, 'spVal': 0.131833, 'right': 22.478291, 'left': {'spInd': 0, 'spVal': 0.138619, 'right': -29.087463, 'left': {'spInd': 0, 'spVal': 0.156067, 'right': {'spInd': 0, 'spVal': 0.13988, 'right': 7.336784, 'left': 7.557349}, 'left': {'spInd': 0, 'spVal': 0.166765, 'right': {'spInd': 0, 'spVal': 0.156273, 'right': 0.225886, 'left': {'spInd': 0, 'spVal': 0.164134, 'right': -27.405211, 'left': {'spInd': 0, 'spVal': 0.166431, 'right': -6.512506, 'left': -14.740059}}}, 'left': {'spInd': 0, 'spVal': 0.193282, 'right': {'spInd': 0, 'spVal': 0.176523, 'right': 0.946348, 'left': 18.208423}, 'left': {'spInd': 0, 'spVal': 0.211633, 'right': {'spInd': 0, 'spVal': 0.202161, 'right': {'spInd': 0, 'spVal': 0.199903, 'right': -3.372472, 'left': -1.983889}, 'left': {'spInd': 0, 'spVal': 0.203993, 'right': -22.379119, 'left': {'spInd': 0, 'spVal': 0.206207, 'right': -12.619036, 'left': -8.332207}}}, 'left': {'spInd': 0, 'spVal': 0.228473, 'right': {'spInd': 0, 'spVal': 0.222271, 'right': {'spInd': 0, 'spVal': 0.218321, 'right': {'spInd': 0, 'spVal': 0.217214, 'right': -3.958752, 'left': 1.410768}, 'left': -9.255852}, 'left': {'spInd': 0, 'spVal': 0.2232, 'right': 15.501642, 'left': 19.425158}}, 'left': {'spInd': 0, 'spVal': 0.25807, 'right': {'spInd': 0, 'spVal': 0.228628, 'right': -2.266273, 'left': {'spInd': 0, 'spVal': 0.228751, 'right': -30.812912, 'left': {'spInd': 0, 'spVal': 0.232802, 'right': 1.222318, 'left': -20.425137}}}, 'left': {'spInd': 0, 'spVal': 0.284794, 'right': {'spInd': 0, 'spVal': 0.273863, 'right': {'spInd': 0, 'spVal': 0.264926, 'right': {'spInd': 0, 'spVal': 0.264639, 'right': 2.557923, 'left': 5.280579}, 'left': -9.457556}, 'left': 35.623746}, 'left': {'spInd': 0, 'spVal': 0.300318, 'right': {'spInd': 0, 'spVal': 0.297107, 'right': {'spInd': 0, 'spVal': 0.295993, 'right': {'spInd': 0, 'spVal': 0.290749, 'right': -14.391613, 'left': -14.988279}, 'left': -1.798377}, 'left': -18.051318}, 'left': 8.814725}}}}}}}}}}, 'left': {'spInd': 0, 'spVal': 0.310956, 'right': -49.939516, 'left': {'spInd': 0, 'spVal': 0.318309, 'right': -27.605424, 'left': -13.189243}}}, 'left': {'spInd': 0, 'spVal': 0.32889, 'right': 39.783113, 'left': {'spInd': 0, 'spVal': 0.331364, 'right': -1.290825, 'left': {'spInd': 0, 'spVal': 0.3349, 'right': 18.97665, 'left': 2.768225}}}}, 'left': {'spInd': 0, 'spVal': 0.370042, 'right': {'spInd': 0, 'spVal': 0.35679, 'right': {'spInd': 0, 'spVal': 0.350725, 'right': {'spInd': 0, 'spVal': 0.350065, 'right': {'spInd': 0, 'spVal': 0.342761, 'right': {'spInd': 0, 'spVal': 0.342155, 'right': {'spInd': 0, 'spVal': 0.3417, 'right': -23.547711, 'left': -16.930416}, 'left': -31.584855}, 'left': -1.319852}, 'left': -40.086564}, 'left': {'spInd': 0, 'spVal': 0.351478, 'right': -0.461116, 'left': -19.526539}}, 'left': -32.124495}, 'left': {'spInd': 0, 'spVal': 0.378965, 'right': {'spInd': 0, 'spVal': 0.373501, 'right': -8.228297, 'left': {'spInd': 0, 'spVal': 0.377383, 'right': 5.241196, 'left': 13.583555}}, 'left': -29.007783}}}, 'left': {'spInd': 0, 'spVal': 0.388789, 'right': {'spInd': 0, 'spVal': 0.385021, 'right': 24.816941, 'left': 21.578007}, 'left': {'spInd': 0, 'spVal': 0.437652, 'right': {'spInd': 0, 'spVal': 0.412516, 'right': {'spInd': 0, 'spVal': 0.403228, 'right': {'spInd': 0, 'spVal': 0.391609, 'right': 3.001104, 'left': -1.729244}, 'left': -26.419289}, 'left': {'spInd': 0, 'spVal': 0.418943, 'right': 44.161493, 'left': {'spInd': 0, 'spVal': 0.426711, 'right': -21.594268, 'left': {'spInd': 0, 'spVal': 0.428582, 'right': 15.224266, 'left': 19.745224}}}}, 'left': {'spInd': 0, 'spVal': 0.454312, 'right': {'spInd': 0, 'spVal': 0.446196, 'right': -5.108172, 'left': {'spInd': 0, 'spVal': 0.451087, 'right': -28.724685, 'left': -20.360067}}, 'left': {'spInd': 0, 'spVal': 0.454375, 'right': 3.043912, 'left': 9.841938}}}}}}}, 'left': -34.044555}, 'left': {'spInd': 0, 'spVal': 0.465561, 'right': {'spInd': 0, 'spVal': 0.463241, 'right': 17.171057, 'left': 30.051931}, 'left': {'spInd': 0, 'spVal': 0.467383, 'right': {'spInd': 0, 'spVal': 0.46568, 'right': -23.777531, 'left': -9.712925}, 'left': {'spInd': 0, 'spVal': 0.483803, 'right': 5.224234, 'left': {'spInd': 0, 'spVal': 0.487381, 'right': 27.729263, 'left': {'spInd': 0, 'spVal': 0.487537, 'right': 5.149336, 'left': 11.924204}}}}}}, 'left': {'spInd': 0, 'spVal': 0.729397, 'right': {'spInd': 0, 'spVal': 0.640515, 'right': {'spInd': 0, 'spVal': 0.613004, 'right': {'spInd': 0, 'spVal': 0.606417, 'right': {'spInd': 0, 'spVal': 0.513332, 'right': {'spInd': 0, 'spVal': 0.508548, 'right': {'spInd': 0, 'spVal': 0.508542, 'right': 96.403373, 'left': 93.292829}, 'left': 101.075609}, 'left': {'spInd': 0, 'spVal': 0.533511, 'right': {'spInd': 0, 'spVal': 0.51915, 'right': 116.176162, 'left': {'spInd': 0, 'spVal': 0.531944, 'right': 124.795495, 'left': 129.766743}}, 'left': {'spInd': 0, 'spVal': 0.548539, 'right': {'spInd': 0, 'spVal': 0.546601, 'right': {'spInd': 0, 'spVal': 0.537834, 'right': 90.995536, 'left': {'spInd': 0, 'spVal': 0.543843, 'right': 98.36201, 'left': 96.319043}}, 'left': 83.114502}, 'left': {'spInd': 0, 'spVal': 0.553797, 'right': {'spInd': 0, 'spVal': 0.549814, 'right': 137.267576, 'left': 120.857321}, 'left': {'spInd': 0, 'spVal': 0.560301, 'right': 82.903945, 'left': {'spInd': 0, 'spVal': 0.599142, 'right': {'spInd': 0, 'spVal': 0.589806, 'right': {'spInd': 0, 'spVal': 0.582311, 'right': {'spInd': 0, 'spVal': 0.571214, 'right': {'spInd': 0, 'spVal': 0.569327, 'right': 108.435392, 'left': 114.872056}, 'left': 82.589328}, 'left': {'spInd': 0, 'spVal': 0.585413, 'right': 125.295113, 'left': 98.674874}}, 'left': 130.378529}, 'left': 93.521396}}}}}}, 'left': 168.180746}, 'left': {'spInd': 0, 'spVal': 0.623909, 'right': {'spInd': 0, 'spVal': 0.618868, 'right': 76.917665, 'left': 87.181863}, 'left': {'spInd': 0, 'spVal': 0.628061, 'right': {'spInd': 0, 'spVal': 0.624827, 'right': 105.970743, 'left': 117.628346}, 'left': {'spInd': 0, 'spVal': 0.637999, 'right': {'spInd': 0, 'spVal': 0.632691, 'right': 93.645293, 'left': 91.656617}, 'left': 82.713621}}}}, 'left': {'spInd': 0, 'spVal': 0.642373, 'right': 140.613941, 'left': {'spInd': 0, 'spVal': 0.642707, 'right': 82.500766, 'left': {'spInd': 0, 'spVal': 0.665329, 'right': {'spInd': 0, 'spVal': 0.661073, 'right': {'spInd': 0, 'spVal': 0.652462, 'right': 112.715799, 'left': 115.687524}, 'left': 121.980607}, 'left': {'spInd': 0, 'spVal': 0.706961, 'right': {'spInd': 0, 'spVal': 0.698472, 'right': {'spInd': 0, 'spVal': 0.689099, 'right': {'spInd': 0, 'spVal': 0.666452, 'right': {'spInd': 0, 'spVal': 0.665652, 'right': 105.547997, 'left': 120.014736}, 'left': {'spInd': 0, 'spVal': 0.667851, 'right': 92.449664, 'left': {'spInd': 0, 'spVal': 0.680486, 'right': 110.367074, 'left': 112.378209}}}, 'left': 120.521925}, 'left': {'spInd': 0, 'spVal': 0.69892, 'right': 92.470636, 'left': {'spInd': 0, 'spVal': 0.699873, 'right': 115.586605, 'left': {'spInd': 0, 'spVal': 0.70639, 'right': 105.062147, 'left': 106.180427}}}}, 'left': {'spInd': 0, 'spVal': 0.70889, 'right': 135.416767, 'left': {'spInd': 0, 'spVal': 0.716211, 'right': {'spInd': 0, 'spVal': 0.710234, 'right': 108.553919, 'left': 103.345308}, 'left': 110.90283}}}}}}}, 'left': {'spInd': 0, 'spVal': 0.952833, 'right': {'spInd': 0, 'spVal': 0.759504, 'right': {'spInd': 0, 'spVal': 0.740859, 'right': {'spInd': 0, 'spVal': 0.731636, 'right': 73.912028, 'left': 93.773929}, 'left': {'spInd': 0, 'spVal': 0.757527, 'right': 63.549854, 'left': 81.106762}}, 'left': {'spInd': 0, 'spVal': 0.763328, 'right': 115.199195, 'left': {'spInd': 0, 'spVal': 0.769043, 'right': 64.041941, 'left': {'spInd': 0, 'spVal': 0.790312, 'right': {'spInd': 0, 'spVal': 0.786865, 'right': {'spInd': 0, 'spVal': 0.785574, 'right': {'spInd': 0, 'spVal': 0.777582, 'right': 100.838446, 'left': 107.024467}, 'left': 100.598825}, 'left': {'spInd': 0, 'spVal': 0.787755, 'right': 118.642009, 'left': 110.15973}}, 'left': {'spInd': 0, 'spVal': 0.806158, 'right': {'spInd': 0, 'spVal': 0.799873, 'right': {'spInd': 0, 'spVal': 0.798198, 'right': 76.853728, 'left': 91.368473}, 'left': 62.877698}, 'left': {'spInd': 0, 'spVal': 0.815215, 'right': {'spInd': 0, 'spVal': 0.811602, 'right': {'spInd': 0, 'spVal': 0.811363, 'right': 112.981216, 'left': 99.841379}, 'left': 118.319942}, 'left': {'spInd': 0, 'spVal': 0.833026, 'right': {'spInd': 0, 'spVal': 0.823848, 'right': {'spInd': 0, 'spVal': 0.819722, 'right': 70.054508, 'left': 59.342323}, 'left': 76.723835}, 'left': {'spInd': 0, 'spVal': 0.841547, 'right': {'spInd': 0, 'spVal': 0.838587, 'right': 134.089674, 'left': 115.669032}, 'left': {'spInd': 0, 'spVal': 0.841625, 'right': 60.552308, 'left': {'spInd': 0, 'spVal': 0.944221, 'right': {'spInd': 0, 'spVal': 0.85497, 'right': {'spInd': 0, 'spVal': 0.84294, 'right': 95.893131, 'left': {'spInd': 0, 'spVal': 0.847219, 'right': 76.240984, 'left': 89.20993}}, 'left': {'spInd': 0, 'spVal': 0.936524, 'right': {'spInd': 0, 'spVal': 0.934853, 'right': {'spInd': 0, 'spVal': 0.925782, 'right': {'spInd': 0, 'spVal': 0.910975, 'right': {'spInd': 0, 'spVal': 0.901444, 'right': {'spInd': 0, 'spVal': 0.901421, 'right': {'spInd': 0, 'spVal': 0.892999, 'right': {'spInd': 0, 'spVal': 0.888426, 'right': {'spInd': 0, 'spVal': 0.872199, 'right': {'spInd': 0, 'spVal': 0.866451, 'right': {'spInd': 0, 'spVal': 0.856421, 'right': 107.166848, 'left': 94.402102}, 'left': 111.552716}, 'left': {'spInd': 0, 'spVal': 0.883615, 'right': {'spInd': 0, 'spVal': 0.872883, 'right': 95.887712, 'left': 95.348184}, 'left': {'spInd': 0, 'spVal': 0.885676, 'right': 108.045948, 'left': 94.896354}}}, 'left': 82.436686}, 'left': {'spInd': 0, 'spVal': 0.900699, 'right': {'spInd': 0, 'spVal': 0.896683, 'right': 107.00162, 'left': 109.188248}, 'left': 100.133819}}, 'left': 87.300625}, 'left': {'spInd': 0, 'spVal': 0.908629, 'right': 118.513475, 'left': 106.814667}}, 'left': {'spInd': 0, 'spVal': 0.912161, 'right': 85.005351, 'left': {'spInd': 0, 'spVal': 0.915263, 'right': 96.71761, 'left': 92.074619}}}, 'left': 115.753994}, 'left': 65.548418}, 'left': {'spInd': 0, 'spVal': 0.937766, 'right': 119.949824, 'left': 100.120253}}}, 'left': {'spInd': 0, 'spVal': 0.948822, 'right': 69.318649, 'left': {'spInd': 0, 'spVal': 0.949198, 'right': 105.752508, 'left': {'spInd': 0, 'spVal': 0.952377, 'right': 73.520802, 'left': 100.649591}}}}}}}}}}}}}, 'left': {'spInd': 0, 'spVal': 0.965969, 'right': {'spInd': 0, 'spVal': 0.956951, 'right': {'spInd': 0, 'spVal': 0.953902, 'right': 130.92648, 'left': {'spInd': 0, 'spVal': 0.954711, 'right': 100.935789, 'left': 82.016541}}, 'left': {'spInd': 0, 'spVal': 0.958512, 'right': 135.837013, 'left': {'spInd': 0, 'spVal': 0.960398, 'right': 123.559747, 'left': 112.386764}}}, 'left': {'spInd': 0, 'spVal': 0.968621, 'right': 98.648346, 'left': 86.399637}}}}}
    

     散点图:

    得到的树很复杂,改变ops元组的值:

    if __name__=='__main__':
        myMat2 = loadDataSet('ex2.txt')
        myMat2 = mat(myMat2)
        myTree = createTree(myMat2, ops=(10000, 4))
        print(myTree)
    

     输出:

    {'spInd': 0, 'spVal': 0.499171, 'left': 101.35815937735848, 'right': -2.637719329787234}
    

     也可以得到仅有两个叶节点的树。


    树剪枝:
    一棵树如果节点过多,表明该模型可能对数据进行了过拟合。通过降低决策树的复杂度来避免过拟合的过程称为“剪枝”。

    在函数chooseBestSplit()中的提前终止条件,实际上是“预剪枝”操作,预剪枝操作对于参数ops元组非常敏感,难以获得有效的回归树。

    后剪枝:利用测试集对数进行剪枝。由于不需要用户指定参数,后剪枝是一种更理想化的剪枝方法。

    首先将数据集划分为训练集和测试集。先使用训练集构建出一棵足够复杂的树便于剪枝。然后从上到下找到叶节点,用测试集来判断这些叶节点合并能不能降低测试误差,如果可以的话就合并。

    伪代码如下:

    基于已有的树切分测试数据:
        如果存在任一子集是一棵树,则在该子集递归剪枝过程
        计算将当前两个叶子节点合并后的误差
        计算不合并的误差
        如果合并会降低误差则合并
    

     回归树剪枝函数prune():

    def isTree(obj):  # 测试输入变量是否是一棵树,返回布尔型的结果,用于判断当前处理的节点是否是叶节点
        return (type(obj).__name__ == "dict")
    def getMean(tree):  # 递归函数,从上到下遍历树直到叶节点为止。如果找到两个叶节点则计算它们的平均值。该函数对树进行塌陷处理
        if isTree(tree["right"]):
            tree["right"] = getMean(tree["right"])
        if isTree(tree["left"]):
            tree["left"] = getMean(tree["left"])
        return (tree["left"] + tree["right"]) / 2.0
    
    def prune(tree, testData):  #参数:待剪枝的树与剪枝所需的测试数据
        if shape(testData)[0] == 0:     #没有测试数据则对树进行塌陷处理
            return getMean(tree)
        if (isTree(tree['right']) or isTree(tree['left'])):  #
            lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        if isTree(tree['left']):
            tree['left'] = prune(tree['left'], lSet)
        if isTree(tree['right']):
            tree['right'] = prune(tree['right'], rSet)
        if not isTree(tree['left']) and not isTree(tree['right']):
            lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
            errorNoMerge = sum(power(lSet[:, -1] - tree['left'], 2)) + sum(power(rSet[:, -1] - tree['right'], 2))
            treeMean = (tree['left'] + tree['right']) / 2.0
            errorMerge = sum(power(testData[:, -1] - treeMean, 2))
            if errorMerge < errorNoMerge:
                print("融合")
                return treeMean
            else:
                return tree
        else:
            return tree
    

     isTree():测试输入变量是否是一棵树,返回布尔值的结果。用于判断当前处理的节点是不是叶子节点。

    getMean():递归函数,从上到下遍历树直到叶节点。如果找到两个叶节点就返回其平均值。该函数对树进行塌陷处理。

    prune():参数为待剪枝的树和剪枝所需的测试数据集。

    测试:

    if __name__=='__main__':
        myMat2=loadDataSet('ex2.txt')
        myMat2=mat(myMat2)
        myTree = createTree(myMat2, ops=(0, 1))
        myDat2Test = loadDataSet("ex2test.txt")
        myMat2Test = mat(myDat2Test)
        result=prune(myTree, myMat2Test)
        print(result)
    

     输出:

    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    融合
    {'left': {'left': {'left': {'left': 92.5239915, 'spInd': 0, 'spVal': 0.965969, 'right': {'left': {'left': {'left': 112.386764, 'spInd': 0, 'spVal': 0.960398, 'right': 123.559747}, 'spInd': 0, 'spVal': 0.958512, 'right': 135.837013}, 'spInd': 0, 'spVal': 0.956951, 'right': 111.2013225}}, 'spInd': 0, 'spVal': 0.952833, 'right': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': 96.41885225, 'spInd': 0, 'spVal': 0.948822, 'right': 69.318649}, 'spInd': 0, 'spVal': 0.944221, 'right': {'left': {'left': 110.03503850000001, 'spInd': 0, 'spVal': 0.936524, 'right': {'left': 65.548418, 'spInd': 0, 'spVal': 0.934853, 'right': {'left': 115.753994, 'spInd': 0, 'spVal': 0.925782, 'right': {'left': {'left': 94.3961145, 'spInd': 0, 'spVal': 0.912161, 'right': 85.005351}, 'spInd': 0, 'spVal': 0.910975, 'right': {'left': {'left': 106.814667, 'spInd': 0, 'spVal': 0.908629, 'right': 118.513475}, 'spInd': 0, 'spVal': 0.901444, 'right': {'left': 87.300625, 'spInd': 0, 'spVal': 0.901421, 'right': {'left': {'left': 100.133819, 'spInd': 0, 'spVal': 0.900699, 'right': 108.094934}, 'spInd': 0, 'spVal': 0.892999, 'right': {'left': 82.436686, 'spInd': 0, 'spVal': 0.888426, 'right': {'left': 98.54454949999999, 'spInd': 0, 'spVal': 0.872199, 'right': 106.16859550000001}}}}}}}}}, 'spInd': 0, 'spVal': 0.85497, 'right': {'left': {'left': 89.20993, 'spInd': 0, 'spVal': 0.847219, 'right': 76.240984}, 'spInd': 0, 'spVal': 0.84294, 'right': 95.893131}}}, 'spInd': 0, 'spVal': 0.841625, 'right': 60.552308}, 'spInd': 0, 'spVal': 0.841547, 'right': 124.87935300000001}, 'spInd': 0, 'spVal': 0.833026, 'right': {'left': 76.723835, 'spInd': 0, 'spVal': 0.823848, 'right': {'left': 59.342323, 'spInd': 0, 'spVal': 0.819722, 'right': 70.054508}}}, 'spInd': 0, 'spVal': 0.815215, 'right': {'left': 118.319942, 'spInd': 0, 'spVal': 0.811602, 'right': {'left': 99.841379, 'spInd': 0, 'spVal': 0.811363, 'right': 112.981216}}}, 'spInd': 0, 'spVal': 0.806158, 'right': 73.49439925}, 'spInd': 0, 'spVal': 0.790312, 'right': {'left': 114.4008695, 'spInd': 0, 'spVal': 0.786865, 'right': 102.26514075}}, 'spInd': 0, 'spVal': 0.769043, 'right': 64.041941}, 'spInd': 0, 'spVal': 0.763328, 'right': 115.199195}, 'spInd': 0, 'spVal': 0.759504, 'right': 78.08564325}}, 'spInd': 0, 'spVal': 0.729397, 'right': {'left': {'left': {'left': {'left': {'left': {'left': {'left': 110.90283, 'spInd': 0, 'spVal': 0.716211, 'right': {'left': 103.345308, 'spInd': 0, 'spVal': 0.710234, 'right': 108.553919}}, 'spInd': 0, 'spVal': 0.70889, 'right': 135.416767}, 'spInd': 0, 'spVal': 0.706961, 'right': {'left': {'left': {'left': {'left': 106.180427, 'spInd': 0, 'spVal': 0.70639, 'right': 105.062147}, 'spInd': 0, 'spVal': 0.699873, 'right': 115.586605}, 'spInd': 0, 'spVal': 0.69892, 'right': 92.470636}, 'spInd': 0, 'spVal': 0.698472, 'right': {'left': 120.521925, 'spInd': 0, 'spVal': 0.689099, 'right': {'left': 101.91115275, 'spInd': 0, 'spVal': 0.666452, 'right': 112.78136649999999}}}}, 'spInd': 0, 'spVal': 0.665329, 'right': {'left': 121.980607, 'spInd': 0, 'spVal': 0.661073, 'right': {'left': 115.687524, 'spInd': 0, 'spVal': 0.652462, 'right': 112.715799}}}, 'spInd': 0, 'spVal': 0.642707, 'right': 82.500766}, 'spInd': 0, 'spVal': 0.642373, 'right': 140.613941}, 'spInd': 0, 'spVal': 0.640515, 'right': {'left': {'left': {'left': {'left': 82.713621, 'spInd': 0, 'spVal': 0.637999, 'right': {'left': 91.656617, 'spInd': 0, 'spVal': 0.632691, 'right': 93.645293}}, 'spInd': 0, 'spVal': 0.628061, 'right': {'left': 117.628346, 'spInd': 0, 'spVal': 0.624827, 'right': 105.970743}}, 'spInd': 0, 'spVal': 0.623909, 'right': 82.04976400000001}, 'spInd': 0, 'spVal': 0.613004, 'right': {'left': 168.180746, 'spInd': 0, 'spVal': 0.606417, 'right': {'left': {'left': {'left': {'left': {'left': {'left': 93.521396, 'spInd': 0, 'spVal': 0.599142, 'right': {'left': 130.378529, 'spInd': 0, 'spVal': 0.589806, 'right': {'left': 111.9849935, 'spInd': 0, 'spVal': 0.582311, 'right': {'left': 82.589328, 'spInd': 0, 'spVal': 0.571214, 'right': {'left': 114.872056, 'spInd': 0, 'spVal': 0.569327, 'right': 108.435392}}}}}, 'spInd': 0, 'spVal': 0.560301, 'right': 82.903945}, 'spInd': 0, 'spVal': 0.553797, 'right': 129.0624485}, 'spInd': 0, 'spVal': 0.548539, 'right': {'left': 83.114502, 'spInd': 0, 'spVal': 0.546601, 'right': {'left': 97.3405265, 'spInd': 0, 'spVal': 0.537834, 'right': 90.995536}}}, 'spInd': 0, 'spVal': 0.533511, 'right': {'left': {'left': 129.766743, 'spInd': 0, 'spVal': 0.531944, 'right': 124.795495}, 'spInd': 0, 'spVal': 0.51915, 'right': 116.176162}}, 'spInd': 0, 'spVal': 0.513332, 'right': {'left': 101.075609, 'spInd': 0, 'spVal': 0.508548, 'right': {'left': 93.292829, 'spInd': 0, 'spVal': 0.508542, 'right': 96.403373}}}}}}}, 'spInd': 0, 'spVal': 0.499171, 'right': {'left': {'left': {'left': {'left': {'left': 8.53677, 'spInd': 0, 'spVal': 0.487381, 'right': 27.729263}, 'spInd': 0, 'spVal': 0.483803, 'right': 5.224234}, 'spInd': 0, 'spVal': 0.467383, 'right': {'left': -9.712925, 'spInd': 0, 'spVal': 0.46568, 'right': -23.777531}}, 'spInd': 0, 'spVal': 0.465561, 'right': {'left': 30.051931, 'spInd': 0, 'spVal': 0.463241, 'right': 17.171057}}, 'spInd': 0, 'spVal': 0.457563, 'right': {'left': -34.044555, 'spInd': 0, 'spVal': 0.455761, 'right': {'left': {'left': {'left': {'left': {'left': -4.1911745, 'spInd': 0, 'spVal': 0.437652, 'right': {'left': {'left': {'left': {'left': 19.745224, 'spInd': 0, 'spVal': 0.428582, 'right': 15.224266}, 'spInd': 0, 'spVal': 0.426711, 'right': -21.594268}, 'spInd': 0, 'spVal': 0.418943, 'right': 44.161493}, 'spInd': 0, 'spVal': 0.412516, 'right': {'left': -26.419289, 'spInd': 0, 'spVal': 0.403228, 'right': 0.6359300000000001}}}, 'spInd': 0, 'spVal': 0.388789, 'right': 23.197474}, 'spInd': 0, 'spVal': 0.382037, 'right': {'left': {'left': {'left': -29.007783, 'spInd': 0, 'spVal': 0.378965, 'right': {'left': {'left': 13.583555, 'spInd': 0, 'spVal': 0.377383, 'right': 5.241196}, 'spInd': 0, 'spVal': 0.373501, 'right': -8.228297}}, 'spInd': 0, 'spVal': 0.370042, 'right': {'left': -32.124495, 'spInd': 0, 'spVal': 0.35679, 'right': {'left': -9.9938275, 'spInd': 0, 'spVal': 0.350725, 'right': -26.851234812500003}}}, 'spInd': 0, 'spVal': 0.335182, 'right': {'left': 22.286959625, 'spInd': 0, 'spVal': 0.324274, 'right': {'left': {'left': -20.3973335, 'spInd': 0, 'spVal': 0.310956, 'right': -49.939516}, 'spInd': 0, 'spVal': 0.309133, 'right': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': 8.814725, 'spInd': 0, 'spVal': 0.300318, 'right': {'left': -18.051318, 'spInd': 0, 'spVal': 0.297107, 'right': {'left': -1.798377, 'spInd': 0, 'spVal': 0.295993, 'right': {'left': -14.988279, 'spInd': 0, 'spVal': 0.290749, 'right': -14.391613}}}}, 'spInd': 0, 'spVal': 0.284794, 'right': {'left': 35.623746, 'spInd': 0, 'spVal': 0.273863, 'right': {'left': -9.457556, 'spInd': 0, 'spVal': 0.264926, 'right': {'left': 5.280579, 'spInd': 0, 'spVal': 0.264639, 'right': 2.557923}}}}, 'spInd': 0, 'spVal': 0.25807, 'right': {'left': {'left': -9.601409499999999, 'spInd': 0, 'spVal': 0.228751, 'right': -30.812912}, 'spInd': 0, 'spVal': 0.228628, 'right': -2.266273}}, 'spInd': 0, 'spVal': 0.228473, 'right': 6.099239}, 'spInd': 0, 'spVal': 0.211633, 'right': {'left': -16.42737025, 'spInd': 0, 'spVal': 0.202161, 'right': -2.6781805}}, 'spInd': 0, 'spVal': 0.193282, 'right': 9.5773855}, 'spInd': 0, 'spVal': 0.166765, 'right': {'left': {'left': {'left': -14.740059, 'spInd': 0, 'spVal': 0.166431, 'right': -6.512506}, 'spInd': 0, 'spVal': 0.164134, 'right': -27.405211}, 'spInd': 0, 'spVal': 0.156273, 'right': 0.225886}}, 'spInd': 0, 'spVal': 0.156067, 'right': {'left': 7.557349, 'spInd': 0, 'spVal': 0.13988, 'right': 7.336784}}, 'spInd': 0, 'spVal': 0.138619, 'right': -29.087463}, 'spInd': 0, 'spVal': 0.131833, 'right': 22.478291}}}}}, 'spInd': 0, 'spVal': 0.130626, 'right': -39.524461}, 'spInd': 0, 'spVal': 0.126833, 'right': {'left': 22.891675, 'spInd': 0, 'spVal': 0.124723, 'right': {'left': {'left': 6.196516, 'spInd': 0, 'spVal': 0.108801, 'right': {'left': -16.106164, 'spInd': 0, 'spVal': 0.10796, 'right': {'left': -1.293195, 'spInd': 0, 'spVal': 0.085873, 'right': -10.137104}}}, 'spInd': 0, 'spVal': 0.085111, 'right': {'left': 37.820659, 'spInd': 0, 'spVal': 0.084661, 'right': {'left': -24.132226, 'spInd': 0, 'spVal': 0.080061, 'right': {'left': 15.824970500000001, 'spInd': 0, 'spVal': 0.068373, 'right': {'left': -15.160836, 'spInd': 0, 'spVal': 0.061219, 'right': {'left': {'left': {'left': 6.695567, 'spInd': 0, 'spVal': 0.055862, 'right': -3.131497}, 'spInd': 0, 'spVal': 0.053764, 'right': -13.731698}, 'spInd': 0, 'spVal': 0.044737, 'right': 4.091626}}}}}}}}}}}
    View Code

     虽然合并了很多叶节点,但剪枝后的树没有像预期的那样剪枝成两部分。说明后剪枝可能不如预剪枝有效。可以同时使用两种剪枝方式。


    模型树:把叶子节点设定为分段线性函数。利用数生成算法对数据切分,且每份切分数据容易被线性模型表示。该算法的关键在于误差的计算。

    对于给定的数据集,应该先用线性的模型对它拟合,然后计算真是的目标值与模型预测值之间的差值,再将这些差值的平方求和就得到了所需要的误差。

    模型树的叶节点生成函数:

    def linearSolve(dataSet):
        m, n = shape(dataSet)
        X = mat(ones((m, n)))  #第一列仍为1
        Y = mat(ones((m, 1)))
        X[:, 1:n] = dataSet[:, 0:n - 1]
        # print('X:',X)
        Y = dataSet[:, -1]  # 将X,Y中的数据格式化
        # print('Y:',Y)
        xTx = X.T * X
        if linalg.det(xTx) == 0.0:
            raise NameError("此矩阵不可逆。")
            # ws = linalg.pinv(xTx) * (X.T * Y)
        ws = xTx.I * (X.T * Y)
        return ws, X, Y
    
    def modelLeaf(dataSet):  # 当数据不再需要切分的时候它负责生成叶节点模型
        ws, X, Y = linearSolve(dataSet)
        return ws
    def modelErr(dataSet):
        ws, X, Y = linearSolve(dataSet)
        yHat = X * ws
        return sum(power(Y - yHat, 2))
    

     数据集散点图如下:

     测试:

    myMat=mat(loadDataSet('exp2.txt'))
        plotPoint(myMat)
        myTree=createTree(myMat,modelLeaf,modelErr,(1,10))
        print(myTree)
    

     输出结果:

    {'spInd': 0, 'spVal': 0.285477, 'right': matrix([[3.46877936],
            [1.18521743]]), 'left': matrix([[1.69855694e-03],
            [1.19647739e+01]])}
    

     将数据集从x=0.285477分开,分别用两段线性模型来拟合。


     树回归与标准回归的比较:相关系数

    用树回归进行预测的代码:包括回归树和模型树两种树

    def regTreeEval(model, inDat):  #回归树效果评估
        return float(model)
    
    def modelTreeEval(model, inDat):    #模型树效果评估
        n = shape(inDat)[1]
        X = mat(ones((1, n + 1)))
        X[:, 1:n + 1] = inDat
        return float(X * model)
    
    def treeForeCast(tree, inData, modelEval=regTreeEval):
        if not isTree(tree):
            return modelEval(tree, inData)  # 如果输入单个数据或行向量,返回一个浮点值
        else:
            if inData[tree["spInd"]] > tree["spVal"]:
                if isTree(tree["left"]):
                    return treeForeCast(tree["left"], inData, modelEval)
                else:
                    return modelEval(tree["left"], inData)
            else:
                if isTree(tree["right"]):
                    return treeForeCast(tree["right"], inData, modelEval)
                else:
                    return modelEval(tree["right"], inData)
    def createForeCast(tree, testData, modelEval=regTreeEval):  #测试不同回归树的效果
        m = len(testData)
        yHat = mat(zeros((m, 1)))
        for i in range(m):
            yHat[i, 0] = treeForeCast(tree, mat(testData[i]), modelEval)  # 多次调用treeForeCast函数,将结果以列的形式放到yHat变量中
        return yHat
    

     因为代码中已经含有标准线性回归函数(linearSolve),所以不必重新写其生成代码。

    测试:

    if __name__=='__main__':
        trainMat = mat(loadDataSet("bikeSpeedVsIq_train.txt"))
        testMat = mat(loadDataSet("bikeSpeedVsIq_test.txt"))
        myTree = createTree(trainMat, ops=(1, 20))
        yHat = createForeCast(myTree, testMat[:, 0])
        print("回归树的相关系数:", corrcoef(yHat, testMat[:, -1], rowvar=0)[0, 1])
    
        myTree = createTree(trainMat, modelLeaf, modelErr, (1, 20))
        yHat = createForeCast(myTree, testMat[:, 0], modelTreeEval)
        print("模型树的相关系数:", corrcoef(yHat, testMat[:, -1], rowvar=0)[0, 1])
    
        ws, X, Y = linearSolve(trainMat)
        print("线性回归系数:", ws)
        for i in range(shape(testMat)[0]):
            yHat[i] = testMat[i, 0] * ws[1, 0] + ws[0, 0]
        print("线性回归模型的相关系数:", corrcoef(yHat, testMat[:, -1], rowvar=0)[0, 1])
    

     输出:

    回归树的相关系数: 0.964085231822215
    模型树的相关系数: 0.9760412191380629
    线性回归系数: [[37.58916794]
     [ 6.18978355]]
    线性回归模型的相关系数: 0.9434684235674766
    

     相关系数越接近1越好,所以,模型树>回归树>标准线性回归。

  • 相关阅读:
    hibernate持久化框架
    spring之AOP
    spring之bean
    spring之IOC
    pdf文件工具typora
    vsCode写vue项目一键生成.vue模板
    微信小程序瀑布流
    小程序接入阿拉丁
    小程序引入背景图片不显示问题解决
    Mac OS下使用rz和sz
  • 原文地址:https://www.cnblogs.com/zhhy236400/p/9955049.html
Copyright © 2011-2022 走看看