七、代码实现(python)
以下代码来自Peter Harrington《Machine Learing in Action》
本例代码实现算法5,生成最小二乘回归树。
代码如下(保存为CART.py):
# -- coding: utf-8 -- from numpy import * def loadDataSet(fileName): # 获取训练集 dataMat = [] fr = open(fileName) for line in fr.readlines(): curLine = line.strip().split(' ') fltLine = map(float,curLine) dataMat.append(fltLine) return dataMat def binSplitDataSet(dataSet, feature, value): # 该函数接收3个参数,数据集、第几个特征(切分变量)、划分条件(切分点),根据选择的特征和划分条件将数据分成两个区域 mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:][0] mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:][0] return mat0,mat1 def regLeaf(dataSet): # 获取数据集dataSet最后一列的平均值 return mean(dataSet[:,-1]) def regErr(dataSet): # 根据式(4)计算数据集dataSet的平方误差 # var用于计算方差 return var(dataSet[:,-1]) * shape(dataSet)[0] def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)): # 该函数用于寻找对于数据集dataSet的最好切分变量及切分点(即使得平方误差最小),ops用于控制函数停止机制 tolS = ops[0] # 容许的误差下降值 tolN = ops[1] # 切分的最小样本数 if len(set(dataSet[:,-1].T.tolist()[0])) == 1: return None, leafType(dataSet) # 若所有类别值相等,退出,此时无最好切分量 m,n = shape(dataSet) S = errType(dataSet) # 存储数据集的平方误差 bestS = inf bestIndex = 0 # 初始化切分变量 bestValue = 0 # 初始化切分点 for featIndex in range(n-1): # 循环特征数目,featIndex此时为切分变量 for splitVal in set(dataSet[:,featIndex]): # 循环数据集行数,splitVal此时为切分点 mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal) # 根据循环到的切分变量与切分点将数据分成两个区域 if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue # 若切分后的样本点小于最小样本数,退出此次循环,继续下一个循环 newS = errType(mat0) + errType(mat1)# 计算划分后数据集的平方误差 if newS < bestS: # 若新的平方误差更小,更新各个数据 bestIndex = featIndex bestValue = splitVal bestS = newS if (S - bestS) < tolS: # 若误差下降值小于容许的误差下降值,退出,此时无最好切分量 return None, leafType(dataSet) mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue) # 用新的切分变量与切分点将数据分成两个区域 if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): # 若切分后的样本点小于最小样本数,退出,此时无最好切分量 return None, leafType(dataSet) return bestIndex,bestValue def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)): # 该函数根据接收的数据集创建决策树(子树) feat, val = chooseBestSplit(dataSet, leafType, errType, ops) # 寻找对于数据集dataSet的最好切分变量及切分点 if feat == None: return val # 若无最好的切分点,则返回数据集均值作为叶节点 retTree = {} retTree['spInd'] = feat retTree['spVal'] = val lSet, rSet = binSplitDataSet(dataSet, feat, val) # 用最好的切分变量与切分点将数据分成两个区域,作为左右子树 retTree['left'] = createTree(lSet, leafType, errType, ops) retTree['right'] = createTree(rSet, leafType, errType, ops) return retTree
以上全部内容参考书籍如下:
李航《统计学习方法》