zoukankan      html  css  js  c++  java
  • 【机器学习实战 第九章】树回归 CART算法的原理与实现

    本文来自《机器学习实战》(Peter Harrington)第九章“树回归”部分,代码使用python3.5,并在jupyter notebook环境中测试通过,推荐clone仓库后run cell all就可以了。

    github地址:https://github.com/gshtime/machinelearning-in-action-python3

    转载请标明原文链接

    1 原理

    CART(Classification and Regression Trees,分类回归树)是决策树算法的一种,这种树构建算法既可以用于分类也可以用于回归。

    它采用一种递归二元分割(recursive binary splitting)的技术,分割方法采用基于最小距离的基尼指数(分类树中)或最小平方残差(回归树中)等方法来估计函数的不纯度,从而将当前的样本集分为两个子样本集,使得生成的的每个非叶子节点都有两个分支。因此,CART算法生成的决策树是结构简洁的二叉树。

    因此,CART的目标是:选择输入变量和那些变量上的分割点,直到创建出适当的树。在这个过程中,使用贪婪算法(greedy algorithm)选择使用哪个输入变量和分割点,以使成本函数(cost function)最小化。

    1.1 CART回归树的原理

    本文主要讲解CART回归树的原理及实现

    现在关注一下回归树的 CART 算法的细节。简要来说,创建一个决策树包含两步:

    1. 把预测器空间,即一系列可能值 (X_1,X_2,...,X_p) 分成 (J) 个不同的且非重叠的区域 (R_1,R_2,...,R_J)

    2. 对进入区域 (R_J) 的每一个样本观测值都进行相同的预测,该预测就是 (R_J) 中训练样本预测值的均值。

    为了创建 (J) 个区域 (R_1,R_2,...,R_J),预测器区域被分为高维度的矩形或盒形。其目的在于通过下列式子找到能够使 (RSS) 最小化的盒形区域 (R_1,R_2,...,R_J)

    [sum_{j=1}^{J} sum_{i in R_j} ig(y_i - hat{y}_{R_j}ig)^2 ]

    其中,(hat{y}_{R_j}) 即是第 (j) 个盒形中训练观测的平均预测值。

    鉴于这种空间分割在计算上是不可行的,因此我们常使用贪婪方法(greedy approach)来划分区域,叫做递归二元分割(recursive binary splitting)。

    它是贪婪的(greedy),这是因为在创建树过程中的每一步骤,最佳分割都会在每个特定步骤选定,而不是对未来进行预测,并选取一个将会在未来步骤中出现且有助于创建更好的树的分割。注意所有的划分区域 (R_j,∀j∈[1,J]) 都是矩形。为了进行递归二元分割,首先选取预测器 (X_j) (即数据集中的一个特征)和切割点 (s)(即该特征下某一个数据的值),递归遍历该特征下面所有的值作为二元分割的切割点,对预测器(特征)下的数据分割到不同的区域,即:(R_1(j,s)=ig{ X|Xj < s ig} 和 R_2(j,s)=ig{ X|Xj ge s ig}),使得代价函数RSS得到最大程度的下降。从数学上讲,就是要寻找区域数J(我理解为叶节点数量)和分割点s,使分割后的代价函数最小化:
    ​​

    [sum_{i: x_i in R_1(j,s)} ig(y_i-hat{y}_{R_1}ig)^2 + sum_{i: x_i in R_2(j,s)} ig(y_i-hat{y}_{R_2}ig)^2 ]

    其中 (hat{y}_{R_1}) 为区域 (R_1(j,s)) 中观察样本的平均预测值,(hat{y}_{R_2}) 为区域 (R_2(j,s)) 的观察样本预测均值。这一过程不断重复以搜寻最好的预测器和切分点,并进一步分隔数据以使每一个子区域内的 RSS 最小化。然而,我们不会分割整个预测器空间,我们只会分割一个或两个前面已经认定的区域。这一过程会一直持续,直到达到停止准则,例如我们可以设定停止准则为每一个区域最多包含 m 个观察样本。一旦我们创建了区域 (R_1、R_2、...、R_J),给定一个测试样本,我们就可以用该区域所有训练样本的平均预测值来预测该测试样本的值。

    2 代码

    2.1 CART回归树实现

    代码比较长,不知道cnblogs中是否能折叠,为了方便复制,还是都放在一块吧,github中的代码是分开的,有需要可以去看。

    原书regTrees.py部分的代码如下

    # -*- coding: utf-8 -*-
    import numpy as np
    
    def loadDataSet(fileName):
        '''
        read the data file using TAB as separator,and store the data in float list
        '''
        dataMat = []
        fr = open(fileName)
        for line in fr.readlines():
            curLine = line.strip().split('	')
            fltLine = list(map(float, curLine))
            dataMat.append(fltLine)
        return dataMat
    
    def binSplitDataSet(dataSet, feature, value):
        mat0 = dataSet[np.nonzero(dataSet[:,feature]  > value)[0],:]
        mat1 = dataSet[np.nonzero(dataSet[:,feature] <= value)[0],:]
        return mat0, mat1
    
    def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
        feat, val =chooseBestSplit(dataSet, leafType, errType, ops)
        if feat == None: return val
        retTree = {}
        retTree['spInd'] = feat
        retTree['spVal'] = val
        lSet, rSet = binSplitDataSet(dataSet, feat, val)
        retTree['left'] = createTree(lSet, leafType, errType, ops)
        retTree['right'] = createTree(rSet, leafType, errType, ops)
        return retTree
    
    def regLeaf(dataSet):
        return np.mean(dataSet[:, -1])
    
    def regErr(dataSet):
        return np.var(dataSet[:,-1]) * np.shape(dataSet)[0]
    
    #choose the best feature and splitting value
    def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
        tolS = ops[0] #tolerant value of S decilne
        tolN = ops[1] #min number of samples to be splitted
        if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
            return None, leafType(dataSet)
        m,n = np.shape(dataSet)
        S = errType(dataSet)
        bestS = np.inf;
        bestIndex= 0;
        bestValue = 0
        for featIndex in range(n-1):
            for splitVal in set(dataSet[:,featIndex].T.tolist()[0]):
                mat0, mat1 = binSplitDataSet(dataSet,featIndex, splitVal)
                if(np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN): continue
                newS = errType(mat0) + errType(mat1)
                if newS < bestS:
                    bestIndex = featIndex
                    bestValue = splitVal
                    bestS = newS
        #verdict whether the deciline of S reach the tolS or not
        if (S - bestS) < tolS:
            return None, leafType(dataSet)
        mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
        if(np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN):
            return None, leafType(dataSet)
        return bestIndex, bestValue
    
    def isTree(obj):
        return (type(obj).__name__=='dict')
    
    def getMean(tree):
        if isTree(tree['right']): tree['right'] = getMean(tree['right'])
        if isTree(tree['left']) : tree['left']  = getMean(tree['left'])
        return (tree['left'] + tree['right'])/2.0
    
    def prune(tree, testData):
        if np.shape(testData)[0] == 0: return getMean(tree)
        if(isTree(tree['right']) or isTree(tree['left'])):
            lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
        if isTree(tree['right']):tree['right']= prune(tree['right'],rSet)
        if not isTree(tree['left']) and not isTree(tree['right']):
            lSet, rSet = binSplitDataSet(testData, tree['spInd'],tree['spVal'])
            errorNoMerge = np.sum(np.power(lSet[:,-1] - tree['left'], 2)) + np.sum(np.power(rSet[:,-1] - tree['right'], 2))
            treeMean = (tree['left']+tree['right'])/2.0
            errorMerge = np.sum(np.power(testData[:,-1] - treeMean, 2))
            if errorMerge < errorNoMerge:
                print("merging")
                return treeMean
            else: return tree
        else: return tree
    
    def linearSolve(dataSet):
        m,n = np.shape(dataSet)
        X = np.mat(np.ones((m,n)))
        Y = np.mat(np.ones((m,1)))
        X[:,1:n] = dataSet[:,0:n-1]
        Y = dataSet[:,-1]
        xTx = X.T*X
        if np.linalg.det(xTx) == 0.0:
            raise NameError("This matrix is singular, cannot do inverse,
    try increasing the second value of ops")
        ws = xTx.I * (X.T*Y)
        return ws, X , Y
    def modelLeaf(dataSet):
        ws, X, Y = linearSolve(dataSet)
        return ws
    
    def modelErr(dataSet):
        ws, X, Y = linearSolve(dataSet)
        yHat = X * ws
        return np.sum(np.power(Y - yHat, 2))
    
    def regTreeEval(model, inDat):
        return float(model)
    
    def modelTreeEval(model, inDat):
        n = np.shape(inDat)[1]
        X = np.mat(np.ones((1,n+1)))
        X[:,1:n+1] = inDat
        return float(X*model)
    
    def treeForecast(tree, inData, modelEval=regTreeEval):
        if not isTree(tree): return modelEval(tree, inData)
        if inData[tree['spInd']] > tree['spVal']:
            if isTree(tree['left']):
                return treeForecast(tree['left'], inData, modelEval)
            else:
                return modelEval(tree['left'], inData)
        else:
            if isTree(tree['right']):
                return treeForecast(tree['right'], inData, modelEval)
            else:
                return modelEval(tree['right'], inData)
    
    def createForecast(tree, testData, modelEval=regTreeEval):
        m = len(testData)
        yHat = np.mat(np.zeros((m,1)))
        for i in range(m):
            yHat[i,0] = treeForecast(tree, np.mat(testData[i]), modelEval)
        return yHat
    

    2.2 使用python3的tkinter库创建GUI

    python 2to3

    原书的代码使针对python2.x环境构建的,在python2.x中应该import Tkinter,而在python3.x中,应该import tkinter才能正常导入Tkinter库

    代码

    # -*- coding:utf-8 -*-
    import tkinter as tk
    
    import matplotlib
    
    matplotlib.use('TkAgg')
    from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
    from matplotlib.figure import Figure
    
    def reDraw(tolS, tolN):
        reDraw.f.clf()
        reDraw.a = reDraw.f.add_subplot(111)
        
        if chkBtnVar.get():
            if tolN < 2: tolN = 2
            myTree = createTree(reDraw.rawDat, modelLeaf, modelErr, (tolS,tolN))
            yHat = createForecast(myTree, reDraw.testDat, modelTreeEval)
        else:
            myTree = createTree(reDraw.rawDat, ops=(tolS, tolN))
            yHat = createForecast(myTree, reDraw.testDat)
            
        reDraw.a.scatter(reDraw.rawDat[:,0].A, reDraw.rawDat[:,1].A, s=5)
        reDraw.a.plot(reDraw.testDat, yHat, linewidth=2.0)
        
        reDraw.canvas.show()
          
    def getInput():
        try:
            tolN = int(tolNentry.get())
        except:
            tolN = 10
            print("enter Integet for tolN")
            tolNentry.delete(0, tk.END)
            tolNentry.insert(0, "10")
        try:
            tolS = float(tolSentry.get())
        except:
            tolS = 1.0
            print("enter Integet for tolS")
            tolNentry.delete(0, tk.END)
            tolNentry.insert(0, "1.0")
        return tolN, tolS
    
    def drawNewTree():
        tolN, tolS = getInput()
        reDraw(tolS, tolN)
    
    root = tk.Tk()
    
    #tk.Label(root, text="Plot Place Holder").grid(row=0, columnspan=3)
    
    reDraw.f = Figure(figsize=(5,4), dpi=100)
    reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)
    reDraw.canvas.show()
    reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)
    
    
    tk.Label(root, text="tolN").grid(row=1, column=0)
    tolNentry = tk.Entry(root)
    tolNentry.grid(row=1, column=1)
    tolNentry.insert(0, '10')
    tk.Label(root, text="tolS").grid(row=2, column=0)
    tolSentry = tk.Entry(root)
    tolSentry.grid(row=2, column=1)
    tolSentry.insert(0, '1.0')
    tk.Button(root, text="ReDraw", command=drawNewTree).grid(row=1,column=2, rowspan=3)
    
    chkBtnVar = tk.IntVar()
    chkBtn = tk.Checkbutton(root, text="Model Tree", variable= chkBtnVar)
    chkBtn.grid(row=3, column=0, columnspan=2)
    
    reDraw.rawDat = np.mat(loadDataSet('./data/sine.txt'))
    reDraw.testDat = np.arange(np.min(reDraw.rawDat[:,0]), np.max(reDraw.rawDat[:,0]), 0.01)
    
    reDraw(1.0, 10)
    
    root.mainloop()
    

    测试代码

    测试的代码都在书里,我的github仓库里也有,有空我再放这儿吧

    注意

    有时候运行tkinter的时候,可能python会无限地崩溃,可以试一下重装matplotlib库来解决

    参考资料

    1. https://zhuanlan.zhihu.com/p/28217071
      这是一篇文章的中文翻译,推荐大家看看该文章的英文原文,这篇文章我觉得写得很棒,对了解CART有很大帮助,文中给出了借助sklearn库的CART实现方法,比较简单,另外作者给了其他决策树算法的文章链接。总之很推荐。
    2. http://blog.csdn.net/u014568921/article/details/45082197

    写得比较仓促,自己也在理解和学习中,如果有不对的地方,还请多多指正。现在时间晚了,回头有空把这篇文章写得更全一点

  • 相关阅读:
    洛谷 P3868 [TJOI2009]猜数字
    洛谷 P2661 信息传递
    hdu 5418 Victor and World
    洛谷 P5024 保卫王国
    洛谷 P2470 [SCOI2007]压缩
    双栈排序 2008年NOIP全国联赛提高组(二分图染色)
    理想的正方形 HAOI2007(二维RMQ)
    10.23NOIP模拟题
    疫情控制 2012年NOIP全国联赛提高组(二分答案+贪心)
    图论模板
  • 原文地址:https://www.cnblogs.com/toone/p/7392832.html
Copyright © 2011-2022 走看看