zoukankan      html  css  js  c++  java
  • 机器学习(四):决策树

    七、代码实现(python)

    以下代码来自Peter Harrington《Machine Learing in Action》
    本例代码实现算法5,生成最小二乘回归树。
    代码如下(保存为CART.py):

    # -- coding: utf-8 --
    from numpy import *
    
    def loadDataSet(fileName):
        # 获取训练集
        dataMat = []
        fr = open(fileName)
        for line in fr.readlines():
            curLine = line.strip().split('	')
            fltLine = map(float,curLine)
            dataMat.append(fltLine)
        return dataMat
    
    def binSplitDataSet(dataSet, feature, value):
        # 该函数接收3个参数,数据集、第几个特征(切分变量)、划分条件(切分点),根据选择的特征和划分条件将数据分成两个区域
        mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:][0]
        mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:][0]
        return mat0,mat1
    
    def regLeaf(dataSet):
        # 获取数据集dataSet最后一列的平均值
        return mean(dataSet[:,-1])
    
    def regErr(dataSet):
        # 根据式(4)计算数据集dataSet的平方误差
        # var用于计算方差
        return var(dataSet[:,-1]) * shape(dataSet)[0]
    
    def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
        # 该函数用于寻找对于数据集dataSet的最好切分变量及切分点(即使得平方误差最小),ops用于控制函数停止机制
        tolS = ops[0]                              # 容许的误差下降值
        tolN = ops[1]                              # 切分的最小样本数
        if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
            return None, leafType(dataSet)         # 若所有类别值相等,退出,此时无最好切分量
        m,n = shape(dataSet)
        S = errType(dataSet)                       # 存储数据集的平方误差
        bestS = inf
        bestIndex = 0                              # 初始化切分变量
        bestValue = 0                              # 初始化切分点
        for featIndex in range(n-1):
            # 循环特征数目,featIndex此时为切分变量
            for splitVal in set(dataSet[:,featIndex]):
                # 循环数据集行数,splitVal此时为切分点
                mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)      # 根据循环到的切分变量与切分点将数据分成两个区域
                if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue # 若切分后的样本点小于最小样本数,退出此次循环,继续下一个循环
                newS = errType(mat0) + errType(mat1)# 计算划分后数据集的平方误差
                if newS < bestS:                    # 若新的平方误差更小,更新各个数据
                    bestIndex = featIndex
                    bestValue = splitVal
                    bestS = newS
        if (S - bestS) < tolS:                      # 若误差下降值小于容许的误差下降值,退出,此时无最好切分量
            return None, leafType(dataSet)
        mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)              # 用新的切分变量与切分点将数据分成两个区域
        if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):                   # 若切分后的样本点小于最小样本数,退出,此时无最好切分量
            return None, leafType(dataSet)
        return bestIndex,bestValue
    
    def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
        # 该函数根据接收的数据集创建决策树(子树)
        feat, val = chooseBestSplit(dataSet, leafType, errType, ops)             # 寻找对于数据集dataSet的最好切分变量及切分点
        if feat == None: return val                 # 若无最好的切分点,则返回数据集均值作为叶节点
        retTree = {}
        retTree['spInd'] = feat
        retTree['spVal'] = val
        lSet, rSet = binSplitDataSet(dataSet, feat, val)                         # 用最好的切分变量与切分点将数据分成两个区域,作为左右子树
        retTree['left'] = createTree(lSet, leafType, errType, ops)
        retTree['right'] = createTree(rSet, leafType, errType, ops)
        return retTree
    

     

    以上全部内容参考书籍如下:
    李航《统计学习方法》

  • 相关阅读:
    2.9数据-paddlepaddle数据集wmt16
    2.8数据-paddlepaddle数据集uci_housing
    2.6数据-paddlepaddle数据集movielens
    2.5数据-paddlepaddle数据集imikolov
    2.4数据-paddlepaddle数据集imdb
    2.3数据-paddlepaddle数据集Conll05
    在android程序中怎么执行ifconfig命令来修改android 的ip地址,
    VMware 11安装Mac OS X 10.10 及安装Mac Vmware Tools.
    xcode7 如何真机测试
    海子
  • 原文地址:https://www.cnblogs.com/pengfeiz/p/11392684.html
Copyright © 2011-2022 走看看