zoukankan      html  css  js  c++  java
  • 连续值的CART(分类回归树)原理和实现

    上一篇我们学习和实现了CART(分类回归树),不过主要是针对离散值的分类实现,下面我们来看下连续值的cart分类树如何实现

    思考连续值和离散值的不同之处:

    二分子树的时候不同:离散值需要求出最优的两个组合,连续值需要找到一个合适的分割点把特征切分为前后两块

    这里不考虑特征的减少问题

    切分数据的不同:根据大于和小于等于切分数据集

    def splitDataSet(dataSet, axis, value,threshold):
        retDataSet = []
        if threshold == 'lt':
            for featVec in dataSet:
                if featVec[axis] <= value:
                    retDataSet.append(featVec)
        else:
            for featVec in dataSet:
                if featVec[axis] > value:
                    retDataSet.append(featVec)
    
        return retDataSet

    选择最好特征的最好特征值

    def chooseBestFeatureToSplit(dataSet):
        numFeatures = len(dataSet[0]) - 1      
        bestGiniGain = 1.0; bestFeature = -1;bsetValue=""
        for i in range(numFeatures):        #遍历特征
            featList = [example[i] for example in dataSet]#得到特征列
            uniqueVals = list(set(featList))       #从特征列获取该特征的特征值的set集合
            uniqueVals.sort()
            for value in uniqueVals:# 遍历所有的特征值
                GiniGain = 0.0
                # 左增益
                left_subDataSet = splitDataSet(dataSet, i, value,'lt')
                left_prob = len(left_subDataSet)/float(len(dataSet))
                GiniGain += left_prob * calGini(left_subDataSet)
                # print left_prob,calGini(left_subDataSet),
                # 右增益
                right_subDataSet = splitDataSet(dataSet, i, value,'gt')
                right_prob = len(right_subDataSet)/float(len(dataSet))
                GiniGain += right_prob * calGini(right_subDataSet)
                # print right_prob,calGini(right_subDataSet),         
                # print GiniGain
                if (GiniGain < bestGiniGain):       #比较是否是最好的结果
                    bestGiniGain = GiniGain         #记录最好的结果和最好的特征
                    bestFeature = i
                    bsetValue=value
        return bestFeature,bsetValue

    生成cart:总体上和离散值的差不多,主要差别在于分支的值要加上大于或者小于等于号

    def createTree(dataSet,labels):
        classList = [example[-1] for example in dataSet]
        # print dataSet
        if classList.count(classList[0]) == len(classList): 
            return classList[0]#所有的类别都一样,就不用再划分了
        if len(dataSet) == 1: #如果没有继续可以划分的特征,就多数表决决定分支的类别
            return majorityCnt(classList)
        bestFeat,bsetValue = chooseBestFeatureToSplit(dataSet)
        # print bestFeat,bsetValue,labels
        bestFeatLabel = labels[bestFeat]
        if bestFeat==-1:
            return majorityCnt(classList)
        myTree = {bestFeatLabel:{}}
        featValues = [example[bestFeat] for example in dataSet]
        uniqueVals = list(set(featValues))
        subLabels = labels[:]
        # print bsetValue
        myTree[bestFeatLabel][bestFeatLabel+'<='+str(round(float(bsetValue),3))] = createTree(splitDataSet(dataSet, bestFeat, bsetValue,'lt'),subLabels)
        myTree[bestFeatLabel][bestFeatLabel+'>'+str(round(float(bsetValue),3))] = createTree(splitDataSet(dataSet, bestFeat, bsetValue,'gt'),subLabels)
        return myTree  

    我们看下连续值的cart大概是什么样的(数据集是我们之前用的100个点的数据集)

  • 相关阅读:
    30岁前不必在乎的28件事
    开发工作的枯燥与快乐
    Ajax技术的初步认识二(转)
    算法实习一
    判断程序是否运行的代码 (有点单例模式的味道)
    SQL2008+SERVER2008的解决方案
    天秤座的爱情(转)
    VB sytem32\ieframe.dll找不到的解决方案
    回忆有感《清明雨上》
    Ajax技术的初步认识一(转)
  • 原文地址:https://www.cnblogs.com/qwj-sysu/p/5981231.html
Copyright © 2011-2022 走看看