zoukankan      html  css  js  c++  java
  • 连续值的CART(分类回归树)原理和实现

    上一篇我们学习和实现了CART(分类回归树),不过主要是针对离散值的分类实现,下面我们来看下连续值的cart分类树如何实现

    思考连续值和离散值的不同之处:

    二分子树的时候不同:离散值需要求出最优的两个组合,连续值需要找到一个合适的分割点把特征切分为前后两块

    这里不考虑特征的减少问题

    切分数据的不同:根据大于和小于等于切分数据集

    def splitDataSet(dataSet, axis, value,threshold):
        retDataSet = []
        if threshold == 'lt':
            for featVec in dataSet:
                if featVec[axis] <= value:
                    retDataSet.append(featVec)
        else:
            for featVec in dataSet:
                if featVec[axis] > value:
                    retDataSet.append(featVec)
    
        return retDataSet

    选择最好特征的最好特征值

    def chooseBestFeatureToSplit(dataSet):
        numFeatures = len(dataSet[0]) - 1      
        bestGiniGain = 1.0; bestFeature = -1;bsetValue=""
        for i in range(numFeatures):        #遍历特征
            featList = [example[i] for example in dataSet]#得到特征列
            uniqueVals = list(set(featList))       #从特征列获取该特征的特征值的set集合
            uniqueVals.sort()
            for value in uniqueVals:# 遍历所有的特征值
                GiniGain = 0.0
                # 左增益
                left_subDataSet = splitDataSet(dataSet, i, value,'lt')
                left_prob = len(left_subDataSet)/float(len(dataSet))
                GiniGain += left_prob * calGini(left_subDataSet)
                # print left_prob,calGini(left_subDataSet),
                # 右增益
                right_subDataSet = splitDataSet(dataSet, i, value,'gt')
                right_prob = len(right_subDataSet)/float(len(dataSet))
                GiniGain += right_prob * calGini(right_subDataSet)
                # print right_prob,calGini(right_subDataSet),         
                # print GiniGain
                if (GiniGain < bestGiniGain):       #比较是否是最好的结果
                    bestGiniGain = GiniGain         #记录最好的结果和最好的特征
                    bestFeature = i
                    bsetValue=value
        return bestFeature,bsetValue

    生成cart:总体上和离散值的差不多,主要差别在于分支的值要加上大于或者小于等于号

    def createTree(dataSet,labels):
        classList = [example[-1] for example in dataSet]
        # print dataSet
        if classList.count(classList[0]) == len(classList): 
            return classList[0]#所有的类别都一样,就不用再划分了
        if len(dataSet) == 1: #如果没有继续可以划分的特征,就多数表决决定分支的类别
            return majorityCnt(classList)
        bestFeat,bsetValue = chooseBestFeatureToSplit(dataSet)
        # print bestFeat,bsetValue,labels
        bestFeatLabel = labels[bestFeat]
        if bestFeat==-1:
            return majorityCnt(classList)
        myTree = {bestFeatLabel:{}}
        featValues = [example[bestFeat] for example in dataSet]
        uniqueVals = list(set(featValues))
        subLabels = labels[:]
        # print bsetValue
        myTree[bestFeatLabel][bestFeatLabel+'<='+str(round(float(bsetValue),3))] = createTree(splitDataSet(dataSet, bestFeat, bsetValue,'lt'),subLabels)
        myTree[bestFeatLabel][bestFeatLabel+'>'+str(round(float(bsetValue),3))] = createTree(splitDataSet(dataSet, bestFeat, bsetValue,'gt'),subLabels)
        return myTree  

    我们看下连续值的cart大概是什么样的(数据集是我们之前用的100个点的数据集)

  • 相关阅读:
    客户端用不用bind的区别
    软件项目开发流程
    很强大的几个网站
    MIS系统的 5个基本模块
    asp.net 生成导出word表单 ,导出excel; dataTable生成xls文件,返回前台下载;asp.net启动excel错误 80070005;excelxls columnName 不能改变; 读写excel的开源利器NPOI; 设置excel Cell的数据类型;
    正则快速入门
    asp.net 服务端调用客户端脚本; asp.net 服务端将文件传给客户端; reponse.ContentType的取值;用OutputStream.Write返回文件,效率是WriteFile的10倍;download link click和OutputStream的比较;
    常用编辑
    数据库流程图设计
    scroll位置控制 window的和div的
  • 原文地址:https://www.cnblogs.com/qwj-sysu/p/5981231.html
Copyright © 2011-2022 走看看