zoukankan      html  css  js  c++  java
  • 决策树CART的python实现

    CART算法只做二元切分,因此每个树节点包含待切分的特征,待切分的特征值,左子树,右子树。

    import numpy as np
    
    class treeNode(object):
        def __init__(self, feat, val, right, left):
            featureToSplitOn = feat
            valueOfSplit = val
            rightBranch = right
            leftBranch = left

    给定特征和特征值的数据切分函数

    1 def binSplitDataSet(dataSet, feature, value):
    2     mat0 = dataSet[np.nonzero(dataSet[:, feature] > value)[0], :]
    3     mat1 = dataSet[np.nonzero(dataSet[:, feature] <= value)[0], :]
    4     return mat0, mat1

    树构建函数,包含4个参数,数据集和其他3个可选参数。这些可选参数决定了树的类型:leafType给出建立叶节点的函数;errType代表误差计算函数;ops是一个包含树构建所需其他参数的元组。

    createTree是一个递归函数。该函数首先尝试将数据集分成两个部分,切分由函数chooseBestSplit()完成。如果满足停止条件,chooseBestSplit将返回None和某类模型的值。如果构建的是回归树,该模型是一个常数。如果是模型树,模型是一个线性方程。如果不满足停止条件,chooseBestSplit将创建一个新的Python字典并将数据集分成两份,在这两份数据集上将分别继续递归调用createTree函数。

     1 def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
     2     feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
     3     if feat == None:
     4         return val
     5     retTree = {}
     6     retTree["spInd"] = feat
     7     retTree["spVal"] = val
     8     lSet, rSet = binSplitDataSet(dataSet, feat, val)
     9     retTree["left"] = createTree(lSet, leafType, errType, ops)
    10     retTree["right"] = createTree(rSet, leafType, errType, ops)
    11     return retTree

    1.回归树

    chooseBestSplit只需完成两件事:用最佳方式切分数据集和生成相应地叶节点。其包含:dataSet、leafType、errType和ops。leafType是对创建叶节点的函数的引用,errType是计算总方差函数的引用。chooseBestSplit函数的目标是找到数据集切分的最佳位置。它遍历所有特征及其所有特征值来找到使误差最小化的切分阈值。

     1 def regLeaf(dataSet):
     2     return np.mean(dataSet[:, -1])   # 数据最后一列为预测值
     3 
     4 def regErr(dataSet):
     5     return np.var(dataSet[:, -1]) * np.shape(dataSet)[0]
     6 
     7 def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
     8     tolS = ops[0]
     9     tolN = ops[1]
    10     if len(set(dataSet[:, -1].tolist())) == 1:
    11         return None, leafType(dataSet)
    12     m, n = np.shape(dataSet)
    13     S = errType(dataSet)
    14     bestS = np.inf
    15     bestIndex = 0
    16     bestValue = 0
    17     for featIndex in range(n-1):
    18         for splitVal in set(dataSet[:, featIndex]):
    19             mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
    20             if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN):
    21                 continue
    22             newS = errType(mat0) + errType(mat1)
    23             if newS < bestS:
    24                 bestIndex = featIndex
    25                 bestValue = splitVal
    26                 bestS = newS
    27     if (S - bestS) < tolS:    # 如果误差减少不大则退出
    28         return None, leafType(dataSet)
    29     mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    30     if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN):    # 如果切分出的数据集很小则退出
    31         return None, leafType(dataSet)
    32     return bestIndex, bestValue

    2.模型树

    用树来对数据建模,除了把叶节点简单的设定为常数值之外,还有一种方法是把叶节点设定为分段线性函数,这里所谓的分段线性是指模型由多个线性片段组成。

    为了找到最佳切分应该怎样计算误差呢?对于给定的数据集,应该先用线性的模型对它进行拟合,然后计算真实的目标值与模型预测值间的差值。最后将这些差值的平方求和就得到了所需的误差。

     1 def linearSolve(dataSet):
     2     m, n = np.shape(dataSet)
     3     X = np.mat(np.ones((m, n)))
     4     Y = np.mat(np.ones((m, 1)))
     5     X[:, 1:n] = dataSet[:, 0:n-1]
     6     Y = dataSet[:, -1]
     7     xTx = x.T * X
     8     if np.linalg.det(xTx) == 0:
     9         raise NameError('This matrix is singular, cannot do inverse, try increasing the second value of ops')
    10     ws = xTx.I * (X.T * Y)
    11     return ws, X, Y
    12 
    13 def modelLeaf(dataSet):
    14     ws, X, Y = linearSolve(dataSet)
    15     return ws
    16 
    17 def modelErr(dataSet):
    18     ws, X, Y = linearSolve(dataSet)
    19     yHat = X * ws
    20     return np.sum(np.power(Y - yHat, 2))
  • 相关阅读:
    MyPHPdumpTool:MySQL 数据库备份处理方案
    sdcvx:轻量级的词典工具
    Fedora中你用GNOME还是KDE?
    Linux/GNU课程
    Fireflix:便利 Flickr 用户的 Firefox 扩展
    gtkchtheme
    recordMyDesktop:录制你的 Linux 桌面
    Fedora 8.0 NS2.33拆卸手记
    办理selinux招致无法进入零碎
    ie在Ubuntu8.04下的安装进程
  • 原文地址:https://www.cnblogs.com/ningjing213/p/10838476.html
Copyright © 2011-2022 走看看