zoukankan      html  css  js  c++  java
  • 【AdaBoost算法】强分类器训练过程

    一、强分类器训练过程

    算法原理如下(参考自VIOLA P, JONES M. Robust real time object detection[A] . 8th IEEE International Conference on Computer Vision[C] . Vancouver , 2001.)

    • 给定样本 (x1; y1) , . . . , (xn; yn) ; 其中yi = 0表示负样本,yi =1表示正样本;
    • 初始化权重:负样本权重W0i= 1/2m, 正样本权重W1i = 1/ 2l,其中m为负样本总数,l为正样本总数;
    • 对于t = 1, ... T(T为训练次数):
      1. 权重归一化,简单说就是使本轮所有样本的权重的和为1;
      2. 根据每一个特征训练简单分类器,仅使用一个特征;
      3. 从所有简单分类器中选出一个分错率最低的分类器,为弱分类器;
      4. 更新权重
    • 最后组合T个弱分类器为强分类器

    二、代码实现及说明(python)

    目的:训练得到一个强分类器,该强分类器分错率低于预设值,且该强分类器由若干个弱分类器(对应单个特征)组成,通过若干个分类器及其权重计算得到的值对样本进行分类。

    def adaBoostTrainDS(dataArr,classLabels,numIt=40): 
        weakClassArr = [] #存放强分类器的所有弱分类器信息
        m = shape(dataArr)[0] 
        D = mat(ones((m,1))/m)   #权重初始化
        aggClassEst = mat(zeros((m,1)))
        for i in range(numIt):
            bestStump,error,classEst = buildStump(dataArr,classLabels,D)#根据训练样本、权重得到一个弱分类器
    
            print "D:",D.T
            alpha = float(0.5*log((1.0-error)/max(error,1e-16)))#计算alpha值,该值与分错率相关,分错率越小,该值越大,弱分类器权重
                                                                #max(error,1e-16)用于确保错误为0时不会发生除0溢出
            bestStump['alpha'] = alpha  
            weakClassArr.append(bestStump)  #存储该弱分类
            print "classEst: ",classEst.T
            expon = multiply(-1*alpha*mat(classLabels).T,classEst) 
            D = multiply(D,exp(expon))  #重新计算样本权重
            D = D/D.sum() #归一化
            #计算当前强分类器的分错率,达到预期要求即停止
            aggClassEst += alpha*classEst
            print "aggClassEst: ",aggClassEst.T
            aggErrors = multiply(sign(aggClassEst) != mat(classLabels).T,ones((m,1))) #计算数据点哪个是错误
            print 'aggErrors: ',sign(aggClassEst) != mat(classLabels).T
            print 'aggErrors: ',aggErrors
            errorRate = aggErrors.sum()/m #计算错误率
            print "total error: ",errorRate
            if errorRate == 0.0: break
        return weakClassArr

    三、运行结果

    训练样本:

        datMat = matrix([[ 1. ,  2.1,  0.3],
                                     [ 2. ,  1.1,  0.4],
                                     [ 1.3,  1. ,  1.2],
                                     [ 1. ,  1. ,  1.1],
                                     [ 2. ,  1. ,  1.3],
                                     [ 7. ,  2. ,  0.35]])
        classLabels = [1.0, 1.0, 1.0, -1.0, -1.0, -1.0]

    训练得到的强分类器(强分类器分错率:0%,单个弱分类器最小分错率为33%,在上一篇已经测试过):

    [{'dim': 0, 'ineq': 'gt', 'thresh': 1.6000000000000001, 'alpha': 0.34657359027997275},

    {'dim': 1, 'ineq': 'lt', 'thresh': 1.0, 'alpha': 0.5493061443340549},

    {'dim': 0, 'ineq': 'gt', 'thresh': 2.2000000000000002, 'alpha': 0.5493061443340549},

    {'dim': 2, 'ineq': 'gt', 'thresh': 0.29999999999999999, 'alpha': 0.4777557225137181},

    {'dim': 0, 'ineq': 'lt', 'thresh': 1.0, 'alpha': 0.49926441505556346}]

    手动计算分类:

    针对第一个样本[ 1. ,  2.1,  0.3],利用强分类器计算结果如下:
    - 0.34657359027997275

    - 0.5493061443340549

    - 0.5493061443340549

    + 0.4777557225137181

    + 0.49926441505556346

    = -0.468165741378801--->小于0,正样本

    针对第六个样本[ 7. ,  2. ,  0.35],利用强分类器计算结果如下:
    + 0.34657359027997275

    - 0.5493061443340549

    + 0.5493061443340549

    + 0.4777557225137181

    - 0.49926441505556346

    = +0.3250648977381274--->大于0,负样本

    其它样本的计算类似


    结论:

    强分类器分类,即通过若干个分类器的权重的正负号计算得出,而正负号是通过该若分类器的阈值判断得到;

    强分类器比弱分类器准确率高。

  • 相关阅读:
    Java基础00-循环语句7
    Java基础00-分支语句6
    Java基础00-数据输入5
    Java基础00-运算符4
    Java基础00-基础语法3
    Java基础00-第一个程序2
    第十四题
    第十三题
    第十二题
    第十题
  • 原文地址:https://www.cnblogs.com/chenpi/p/5128234.html
Copyright © 2011-2022 走看看