zoukankan      html  css  js  c++  java
  • AdaBoost算法

    理论知识可参考:《统计学习方法》 (李航 著) 第八章

    简单代码实现:

     1 from numpy import *
     2 import matplotlib.pyplot as plt
     3 
     4 def loadSimpData():
     5     dataMat = matrix([[1,2.1],
     6         [2,1.1],
     7         [1.3,1],
     8         [1,1],
     9         [2,1]])
    10     classLabels = [1.0, 1.0, -1.0, -1.0, 1.0]
    11     return dataMat, classLabels
    12 
    13 def stumpClassify(dataMatrix, dimen, threshVal, threshIneq):
    14     retArray = ones((shape(dataMatrix)[0], 1))
    15     if threshIneq == 'lt':
    16         retArray[dataMatrix[:, dimen] <= threshVal] = -1.0
    17     else:
    18         retArray[dataMatrix[:, dimen] > threshVal] = -1.0
    19     return retArray
    20 
    21 def buildStump(dataArr, classLabels, D):
    22     dataMatrix = mat(dataArr); labelMat = mat(classLabels).T
    23     m, n = shape(dataMatrix)
    24     numSteps = 10.0; bestStump = {}; bestClasEst = mat(zeros((m, 1)))
    25     minError = inf
    26     for i in range(n):
    27         rangeMin = dataMatrix[:, i].min(); rangeMax = dataMatrix[:,i].max();
    28         stepSize = (rangeMax-rangeMin)/numSteps
    29         for j in range(-1, int(numSteps)+1):
    30             for inequal in ['lt', 'gt']:
    31                 threshVal = (rangeMin + float(j) * stepSize)
    32                 predictedVals = stumpClassify(dataMatrix, i, threshVal, inequal)
    33                 errArr = mat(ones((m, 1)))
    34                 errArr[predictedVals == labelMat] = 0
    35                 weightedError = D.T * errArr
    36                 #print("split: dim %d, thresh %.2f, thresh ineqal: %s, the weighted error is %.3f" % (i, threshVal, inequal, weightedError))
    37                 if weightedError < minError:
    38                     minError = weightedError
    39                     bestClasEst = predictedVals.copy()
    40                     bestStump['dim'] = i
    41                     bestStump['thresh'] = threshVal
    42                     bestStump['ineq'] = inequal
    43     return bestStump, minError, bestClasEst
    44 
    45 def adaBoostTrainDS(dataArr, classLabels, numIt = 40):
    46     weakClassArr = []
    47     m = shape(dataArr)[0]
    48     D = mat(ones((m, 1))/m)
    49     aggClassEst = mat(zeros((m, 1)))
    50     for i in range(numIt):
    51         bestStump, error, classEst = buildStump(dataArr, classLabels, D)
    52         print("D:", D.T)
    53         alpha = float(0.5*log((1.0-error)/max(error, 1e-16)))
    54         bestStump['alpha'] = alpha
    55         weakClassArr.append(bestStump)
    56         print("classEst:", classEst)
    57         expon = multiply(-1*alpha*mat(classLabels).T, classEst)
    58         D = multiply(D, exp(expon))
    59         D = D/D.sum()
    60         aggClassEst += alpha*classEst
    61         print("aggClassEst:", aggClassEst.T)
    62         aggErrors = multiply(sign(aggClassEst) != mat(classLabels).T, ones((m, 1)))
    63         errorRate = aggErrors.sum()/m
    64         print("total error:", errorRate, "
    ")
    65         if errorRate == 0.0: break
    66     return weakClassArr
    67 
    68 dataMat, classLabels = loadSimpData()
    69 D = mat(ones((5, 1))/5)
    70 classifierArray = adaBoostTrainDS(dataMat, classLabels, 9)
    71 print(classifierArray)
    View Code
  • 相关阅读:
    maven中没找到settings.xml文件怎么办,简单粗暴
    如何修改新建后的maven的jdk版本号,简单粗暴
    如何修改maven下载的jar包存放位置,简单粗暴方法
    Kafka 温故(一):Kafka背景及架构介绍
    八、Kafka总结
    七、Kafka 用户日志上报实时统计之编码实践
    六、Kafka 用户日志上报实时统计之分析与设计
    五、Kafka 用户日志上报实时统计之 应用概述
    四、Kafka 核心源码剖析
    三、消息处理过程与集群维护
  • 原文地址:https://www.cnblogs.com/JustForCS/p/5289146.html
Copyright © 2011-2022 走看看