Notes on Machine Learning in Action: Tree Regression

    Shortcomings of linear regression:

      Building the model requires fitting all of the samples at once (except in locally weighted linear regression), which becomes clumsy when the data has many features and the relationships among them are complex.

    Tree regression:

      Split the dataset into several pieces that are each easy to model, then build linear models on these smaller subsets. Tree regression uses binary splits, so it can only produce binary trees.
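      As a quick illustration of the binary-split idea (a minimal sketch on made-up two-column data, one feature plus a target, not taken from the book), a single split just partitions the rows by comparing one feature against a split value; this is exactly what the binSplitDataSet function in the listing below does:

      # Toy example of one binary split: rows with feature > 0.5 form one subset,
      # the rest form the other (made-up data, not from the book).
      from numpy import mat, nonzero

      toy = mat([[0.1, 1.0],
                 [0.4, 1.1],
                 [0.6, 3.0],
                 [0.9, 3.2]])
      left = toy[nonzero(toy[:, 0] > 0.5)[0], :]    # rows whose feature exceeds the split value
      right = toy[nonzero(toy[:, 0] <= 0.5)[0], :]  # remaining rows
      print(left)
      print(right)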

    The CART algorithm:

      The full name is Classification And Regression Trees. Each leaf node stores a single value, the mean of the targets that fall into it, and that mean is used as the prediction. To keep the tree from overfitting, tree-pruning techniques are needed.
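      A small worked example of what a CART leaf stores (toy numbers of my own, not the book's): the leaf value is the mean of the target column, and the error the algorithm compares across candidate splits is the total squared deviation, i.e. the variance times the sample count, matching the regLeaf and regErr functions in the listing below:

      from numpy import mat, mean, var, shape

      # Toy leaf with targets 2.0, 2.2, 3.8, 4.0 in the last column
      leaf = mat([[1.0, 2.0], [1.0, 2.2], [1.0, 3.8], [1.0, 4.0]])
      leafValue = mean(leaf[:, -1])                  # 3.0, the prediction this leaf would return
      totalErr = var(leaf[:, -1]) * shape(leaf)[0]   # 3.28, sum of squared deviations from the mean
      print(leafValue, totalErr)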

    The model tree algorithm:

      A linear model is built at each leaf node, i.e., each leaf stores a set of linear-regression coefficients rather than a single value. Building the tree involves parameter tuning, so the book also introduces Python's Tkinter module for building a GUI.
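      Before the full listing, here is a minimal sketch (made-up leaf data) of what "a linear model at each leaf" means: fit y ≈ w0 + w1*x with the normal equations, which is what the linearSolve function below does for every leaf:

      from numpy import mat, ones

      X = mat([[1.0, 0.1], [1.0, 0.2], [1.0, 0.3]])   # bias column of ones plus one feature
      Y = mat([[1.1], [1.9], [3.1]])                   # targets for this (toy) leaf
      ws = (X.T * X).I * (X.T * Y)                     # normal equations; assumes X.T*X is invertible
      print(ws)                                        # the leaf would store these coefficients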

      # _*_ coding:utf-8 _*_

      # 9-1 Code for the CART algorithm
      from numpy import *

      # Load a tab-delimited text file into a list of lists of floats;
      # the last column holds the target value.
      def loadDataSet(fileName):
          dataMat = []
          fr = open(fileName)
          for line in fr.readlines():
              curLine = line.strip().split('\t')
              fltLine = list(map(float, curLine))  # map every field to a float
              dataMat.append(fltLine)
          return dataMat

      # Rows whose value on the given feature is greater than value go into the left
      # subset (mat0); rows less than or equal to value go into the right subset (mat1).
      def binSplitDataSet(dataSet, feature, value):
          mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :]
          mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
          return mat0, mat1

      # Leaf-generation function for regression trees: a leaf stores the mean of the targets.
      def regLeaf(dataSet):
          return mean(dataSet[:, -1])

      # Error function for regression trees: total squared deviation of the targets,
      # i.e. the variance times the number of samples.
      def regErr(dataSet):
          return var(dataSet[:, -1]) * shape(dataSet)[0]

      # 9-2 Split-choosing function for regression trees
      # leafType is a reference to the function that creates a leaf node
      # errType is a reference to the function that computes the total squared error
      # ops is a user-defined tuple that controls when splitting stops: the first value
      # is the minimum error reduction required for a split, the second is the minimum
      # number of samples each side of a split must contain
      def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
          tolS = ops[0]
          tolN = ops[1]
          if len(set(dataSet[:, -1].T.tolist()[0])) == 1:    # erratum: the book has an extra [0] here
              return None, leafType(dataSet)
          m, n = shape(dataSet)
          S = errType(dataSet)
          bestS = inf
          bestIndex = 0
          bestValue = 0
          for featIndex in range(n - 1):
              for splitVal in set((dataSet[:, featIndex].T.A.tolist())[0]):
                  mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
                  if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue
                  newS = errType(mat0) + errType(mat1)
                  if newS < bestS:
                      bestIndex = featIndex
                      bestValue = splitVal
                      bestS = newS
          if (S - bestS) < tolS:
              return None, leafType(dataSet)
          mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
          if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
              return None, leafType(dataSet)
          return bestIndex, bestValue  # return the best feature index and the split value

      # Build the CART tree
      # step 1: use chooseBestSplit to find the best split feature feat and the split value val
      # step 2: if feat is None we have reached a leaf, so return val; otherwise recursively
      #         build the left and right subtrees and return the dict that stores them
      def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
          feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
          if feat is None: return val
          retTree = {}
          retTree['spInd'] = feat
          retTree['spVal'] = val
          lSet, rSet = binSplitDataSet(dataSet, feat, val)
          retTree['left'] = createTree(lSet, leafType, errType, ops)
          retTree['right'] = createTree(rSet, leafType, errType, ops)
          return retTree


      # A node is a subtree if it is stored as a dict; leaves are plain values.
      def isTree(obj):
          return (type(obj).__name__ == 'dict')


      # Collapse a subtree bottom-up into the average of its leaves.
      def getMean(tree):
          if isTree(tree['right']): tree['right'] = getMean(tree['right'])
          if isTree(tree['left']): tree['left'] = getMean(tree['left'])
          return (tree['left'] + tree['right']) / 2.0

      # Post-pruning (performed after the tree is built)
      # Preparation: grow a tree on the training set that is deliberately large, so there is something to prune
      # Procedure: if either child is still a tree, recurse into it with the matching subset of the
      # test data; otherwise compare the test error of the two leaves merged versus kept separate,
      # and merge them when the merged error is smaller
      def prune(tree, testData):
          if shape(testData)[0] == 0: return getMean(tree)
          if (isTree(tree['right']) or isTree(tree['left'])):
              lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
          if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
          if isTree(tree['right']): tree['right'] = prune(tree['right'], rSet)
          if not isTree(tree['left']) and not isTree(tree['right']):
              lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
              errorNoMerge = sum(power(lSet[:, -1] - tree['left'], 2)) + \
                             sum(power(rSet[:, -1] - tree['right'], 2))
              treeMean = (tree['left'] + tree['right']) / 2.0
              errorMerge = sum(power(testData[:, -1] - treeMean, 2))
              if errorMerge < errorNoMerge:
                  print("merging")
                  return treeMean
              else:
                  return tree
          else:
              return tree

      # 9-4 Leaf-generation functions for model trees
      # Format the data into X (with a bias column of ones) and Y, then solve the
      # normal equations for the linear-regression weights ws.
      def linearSolve(dataSet):
          m, n = shape(dataSet)
          X = mat(ones((m, n)))
          Y = mat(ones((m, 1)))
          X[:, 1:n] = dataSet[:, 0:n-1]
          Y = dataSet[:, -1]
          xTx = X.T * X
          if linalg.det(xTx) == 0.0:
              # Python raises exceptions automatically on errors, but raise lets us signal one
              # explicitly; once a raise statement executes, nothing after it in the function runs.
              raise NameError('This matrix is singular, cannot do inverse, '
                              'try increasing the second value of ops')
          ws = xTx.I * (X.T * Y)
          return ws, X, Y


      # Model-tree leaf: the leaf stores the regression weights fitted on its samples.
      def modelLeaf(dataSet):
          ws, X, Y = linearSolve(dataSet)
          return ws


      # Model-tree error: sum of squared residuals of the leaf's linear fit.
      def modelErr(dataSet):
          ws, X, Y = linearSolve(dataSet)
          yHat = X * ws
          return sum(power(Y - yHat, 2))


      # Build a model tree on exp2.txt
      myMat2 = mat(loadDataSet('exp2.txt'))
      modelTrees = createTree(myMat2, modelLeaf, modelErr, (1, 10))
      print(modelTrees)
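
      For completeness, a usage sketch of the post-pruning workflow. The file names ex2.txt and ex2test.txt are the training and test files used in the book's Chapter 9 examples and are assumed to be available here; they are not part of this post:

      # Grow the largest possible tree (ops=(0, 1) effectively disables the early-stopping
      # conditions), then prune it against held-out test data (assumed sample files from the book).
      myDat = mat(loadDataSet('ex2.txt'))
      myTree = createTree(myDat, ops=(0, 1))
      myTestDat = mat(loadDataSet('ex2test.txt'))
      prunedTree = prune(myTree, myTestDat)
      print(prunedTree)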
Original article: https://www.cnblogs.com/DianeSoHungry/p/7097002.html