zoukankan      html  css  js  c++  java
  • 《机器学习实战》笔记——AdaBoost

    笔记见备注

     1 # _*_ coding:utf-8 _*_
     2 from numpy import *
     3 # 简单数据集
     4 def loadSimpData():
     5     datMat = matrix([[1., 2.1],
     6                      [2., 1.1],
     7                      [1.3, 1.],
     8                      [1., 1.],
     9                      [2., 1.]])
    10 
    11     classLabels = [1.0, 1.0, -1.0, -1.0, 1.0]
    12     return datMat, classLabels
    13 
    14 # 7-1 单层决策树生成函数
    15 
    16 # lt=less than
    17 # 分类器的构建(单纯地将某一特征上的所有取值与输入的阈值进行比较,若制定lt为负,则特征值小于阈值的样本被标记为-1)
    18 # 相反而知,若指定gt为负,则特征值大于阈值的样本被标记为-1
    19 def stumpClassify(dataMatrix, dimen, threshVal, threshIneq):
    20     retArray = ones((shape(dataMatrix)[0],1))
    21     if threshIneq == 'lt':
    22         retArray[dataMatrix[:, dimen] <= threshVal] = -1.0
    23     else:
    24         retArray[dataMatrix[:, dimen] > threshVal] = -1.0
    25     return retArray
    26 
    27 # stumpClassify分类器的预测值收到了特征、阈值和阈值两边到底哪边为正标签哪边为父标签的影响
    28 # 所以有三重循环
    29 # 第一重:遍历每个特征
    30 # 第二重:对每个特征上依次设定不同的阈值
    31 # 第三重:每个特征的每个阈值设定以后 还要依次以小、大于阈值作为依据调用分类器。得出预测结果。将结果与真实结果对比,得出错误向量
    32 # 通过错误向量得出加权错误值之后与当前的最小错误值进行对比,迭代后得到最终的最小错误
    33 def buildStump(dataArr, classLabels, D):
    34     dataMatrix = mat(dataArr)
    35     labelMat = mat(classLabels).T
    36     m,n = shape(dataMatrix)
    37     numSteps = 10.0 # 也可以变大,使得阈值的精确度更高,但是会造成计算量的增大
    38     bestStump = {}
    39     bestClasEst = mat(zeros((m,1)))
    40     minError = inf
    41     for i in range(n):
    42         rangeMin = dataMatrix[:,i].min()
    43         rangeMax = dataMatrix[:,i].max()
    44         stepSize = (rangeMax - rangeMin)/numSteps
    45         for j in range(-1, int(numSteps)+1):    # 为什么从-1开始????
    46             for inequal in ['lt','gt']:
    47                 threshVal = (rangeMin + float(j)*stepSize)
    48                 predictedVals = stumpClassify(dataMatrix, i, threshVal, inequal)
    49                 errArr = mat(ones((m, 1)))
    50                 errArr[predictedVals == labelMat] = 0
    51                 weightedError = D.T * errArr
    52                 # print ("split: dim %d, thresh %.2f, thresh ineqal: %s, the weighted error is %.3f"
    53                 # %(i, threshVal, inequal, weightedError))
    54                 if weightedError < minError:
    55                     minError = weightedError
    56                     bestClasEst = predictedVals.copy()
    57                     bestStump['dim'] = i
    58                     bestStump['thresh'] = threshVal
    59                     bestStump['ineq'] = inequal
    60     return bestStump, minError, bestClasEst # 单层决策树建立出来之后,需要知道的是最好的决策树(是的预测值的加权错误最小的树)的、
    61     # 特征、阈值、约定的负标签方向;最小错误值;最好的预测结果(一个以样本个数为维度的向量)
    62 
    63 
    64 
    65 
    66 # 7-2 基于单层决策树的AdaBoost训练过程
    67 def adaBoostTrainDS(dataArr, classLabels, numIt=40):
    68     weakClassArr = []
    69     m = shape(dataArr)[0]
    70     D = mat(ones((m,1))/m)
    71     aggClassEst = mat(zeros((m,1)))
    72     for i in range(numIt):
    73         bestStump, error, classEst = buildStump(dataArr, classLabels, D)
    74         print ("D:", D.T)
    75         alpha = float(0.5*log((1.0-error)/max(error,1e-16)))    # 避免没有错误是发生溢出
    76         bestStump['alpha'] = alpha
    77         weakClassArr.append(bestStump)
    78         print ("classEst: ", classEst.T)
    79         expon = multiply(-1*alpha*mat(classLabels).T, classEst)
    80         D = multiply(D, exp(expon))
    81         D = D/D.sum()
    82         aggClassEst += alpha*classEst
    83         print ("aggClassEst: ",aggClassEst.T)
    84         aggErrors = multiply(sign(aggClassEst) != mat(classLabels).T, ones((m,1)))  #ones zeros的()里面一定不能忘了是元祖,而不是两个数
    85         errorRate = aggErrors.sum()/m
    86         print ("total error: ", errorRate, "
    ")
    87         if errorRate == 0.0: break
    88     return weakClassArr
  • 相关阅读:
    一个具体的例子学习Java volatile关键字
    JavaScript实现的水果忍者游戏,支持鼠标操作
    记录我开发工作中遇到HTTP跨域和OPTION请求的一个坑
    微信程序开发系列教程(四)使用微信API创建公众号自定义菜单
    微信程序开发系列教程(三)使用微信API给微信用户发文本消息
    Java实现 LeetCode 547 朋友圈(并查集?)
    Java实现 LeetCode 547 朋友圈(并查集?)
    Java实现 LeetCode 547 朋友圈(并查集?)
    Java实现 LeetCode 546 移除盒子(递归,vivo秋招)
    Java实现 LeetCode 546 移除盒子(递归,vivo秋招)
  • 原文地址:https://www.cnblogs.com/DianeSoHungry/p/7087979.html
Copyright © 2011-2022 走看看