zoukankan      html  css  js  c++  java
  • 利用AdaBoost方法构建多个弱分类器进行分类

    1.AdaBoost 思想

    补充:这里的若分类器之间有比较强的依赖关系;对于若依赖关系的分类器一般使用Bagging的方法

    弱分类器是指分类效果要比随机猜测效果略好的分类器,我们可以通过构建多个弱分类器来进行最终抉择(俗话说,三个臭皮匠顶个诸葛亮大概就这意思)。首先我们给每个样例初始化一个权重,构成向量D,然后再更新D,更新规则如下:

    当一个样例被分类器正确分类时,我们就减小它的权重

    image

    否则,增大它的权重

    image

    对于每个弱分类器,我们根据它对样例分类错误率来设置它的权重alpha,分类错误率越高,相应的alpha就会越小,如下所示

    image

    最终我们训练出多个弱分类器,通过加权分类结果,输出最终分类结果,如下图所示

    image

    2.实验过程

      1 # -*- coding: utf-8 -*-
      2 """
      3 Created on Wed Mar 29 16:57:37 2017
      4 
      5 @author: MyHome
      6 """
      7 import  numpy as np
      8 
      9 '''返回分类结果向量'''
     10 def stumpClassify(dataMatrix,dimen,threshVal,threshIneq):
     11     retArray = np.ones((np.shape(dataMatrix)[0],1))
     12     if threshIneq == "lt":
     13         retArray[dataMatrix[:,dimen] <= threshVal] = -1.0
     14     else:
     15         retArray[dataMatrix[:,dimen] > threshVal] = -1.0
     16 
     17     return retArray
     18 
     19 '''构造一个最佳决策树,返回决策树字典'''
     20 def buildStump(dataArr,classLabels,D):
     21     dataMatrix = np.mat(dataArr)
     22     labelMat = np.mat(classLabels).T
     23     m,n = dataMatrix.shape
     24     numSteps = 10.0
     25     bestStump = {}
     26     bestClassEst = np.mat(np.zeros((m,1)))
     27     minError = np.inf
     28 
     29     for i in xrange(n):
     30         rangeMin = dataMatrix[:,i].min()
     31         rangeMax = dataMatrix[:,i].max()
     32         stepSize = (rangeMax - rangeMin)/numSteps
     33         for j in xrange(-1,int(numSteps)+1):
     34             for inequal in ["lt","gt"]:
     35                 threshVal = (rangeMin + float(j)*stepSize)
     36                 #print threshVal
     37                 predictedVals = stumpClassify(dataMatrix,i,threshVal,inequal)
     38                 errArr = np.mat(np.ones((m,1)))
     39                 errArr[predictedVals==labelMat] = 0
     40                 weightedError = D.T*errArr
     41 
     42                 if weightedError < minError:
     43                     minError = weightedError
     44                     bestClassEst = predictedVals.copy()
     45                     bestStump["dim"] = i
     46                     bestStump["thresh"] = threshVal
     47                     bestStump["ineq"] = inequal
     48 
     49     return bestStump,minError,bestClassEst
     50 
     51 '''训练多个单层决策树分类器,构成一个数组'''
     52 def adaBoostTrainDS(dataArr,classLabels,numIt =40):
     53     weakClassArr = []
     54     m = np.shape(dataArr)[0]
     55     D = np.mat(np.ones((m,1))/m)
     56     aggClassEst = np.mat(np.zeros((m,1)))
     57     for i in range(numIt):
     58         bestStump,error,classEst = buildStump(dataArr,classLabels,D)
     59         #print "D:",D.T
     60         alpha = float(0.5*np.log((1.0-error)/max(error,1e-16)))
     61         bestStump["alpha"] = alpha
     62         weakClassArr.append(bestStump)
     63         #print "ClassEst:",classEst.T.shape
     64         expon = np.multiply(-1*alpha*np.mat(classLabels).T,classEst)
     65         #print expon
     66         D = np.multiply(D,np.exp(expon))
     67         D = D / D.sum()
     68         aggClassEst += alpha*classEst
     69         #print "aggClassEst: ",aggClassEst.T
     70         aggErrors = np.multiply(np.sign(aggClassEst)!= np.mat(classLabels).T,np.ones((m,1)))
     71         errorRate = aggErrors.sum()/m
     72         print "total error:",errorRate,"
    "
     73         if errorRate ==0.0:
     74             break
     75     return weakClassArr
     76 
     77 
     78 '''分类器'''
     79 def adaClassify(datToClass,classifierArr):
     80     dataMatrix = np.mat(datToClass)
     81     m = np.shape(dataMatrix)[0]
     82     aggClassEst = np.mat(np.zeros((m,1)))
     83     for i in range(len(classifierArr)):
     84         classEst = stumpClassify(dataMatrix,classifierArr[i]["dim"],
     85         classifierArr[i]["thresh"],classifierArr[i]["ineq"])
     86         aggClassEst += classifierArr[i]["alpha"]*classEst
     87         #print aggClassEst
     88     return np.sign(aggClassEst)
     89 
     90 '''载入数据'''
     91 def loadDataSet(fileName):
     92     numFeat = len(open(fileName).readline().split("	"))
     93     dataMat = []
     94     labelMat = []
     95     fr = open(fileName)
     96     for line in fr.readlines():
     97         lineArr = []
     98         curLine = line.strip().split("	")
     99         for i in range(numFeat-1):
    100             lineArr.append(float(curLine[i]))
    101         dataMat.append(lineArr)
    102         labelMat.append(float(curLine[-1]))
    103     #print dataMat,labelMat
    104     return dataMat,labelMat
    105 
    106 if __name__ == "__main__":
    107     datArr,labelArr =  loadDataSet("horseColicTraining2.txt")
    108 
    109     classifierArray = adaBoostTrainDS(datArr,labelArr,10)
    110     testData,testY = loadDataSet("horseColicTest2.txt")
    111     predictionArr = adaClassify(testData,classifierArray)
    112     errorArr = np.mat(np.ones((len(testData),1)))
    113     FinalerrorRate = errorArr[predictionArr!= np.mat(testY).T].sum()/float(errorArr.shape[0])
    114     print "FinalerrorRate:",FinalerrorRate
    115 
    116 

    3.实验结果

    total error: 0.284280936455

    total error: 0.284280936455

    total error: 0.247491638796

    total error: 0.247491638796

    total error: 0.254180602007

    total error: 0.240802675585

    total error: 0.240802675585

    total error: 0.220735785953

    total error: 0.247491638796

    total error: 0.230769230769

    FinalerrorRate: 0.238805970149

    4.实验总结

    通过多个构建多个弱分类器,然后根据各个弱分类器的能力大小(即权重)来对分类结果进行加权求和,得出最终结果。只要数据集比较完整,这种方法还是很强大的,后续还可以尝试更多其他的分类器进行集成。

  • 相关阅读:
    学习Mybatis与mysql数据库的示例笔记
    SpringAOP学习笔记
    idea开发ssh(Spring+struts+Hibernate)实现对MySQL数据库的增删改查
    springmvc加vue实现前后端数据的跨域访问
    idea开发工具springmvc加vue.js实现MySQL数据库的查询操作
    利用idea开发工具实现ssh(spring+struts+hibernate)加vue.js前后台对数据库的查询
    appweb 7.0.2版本编译
    Unable to register the DLL/OCX: RegSvr32 failed with exit code 0x3 我的解决方法
    无法定位程序输入点 InitializeCriticalSectionEx 于动态链接库 Kernel32.dll 上 问题解决方法
    海思3516D + IMX291图像闪烁问题定位
  • 原文地址:https://www.cnblogs.com/lpworkstudyspace1992/p/6668990.html
Copyright © 2011-2022 走看看