For the theory, see Chapter 8 of 《统计学习方法》 (Statistical Learning Methods, by Li Hang 李航).
A simple code implementation:
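For quick reference, the quantities the code below computes are the standard AdaBoost updates from that chapter: the weight of the m-th weak classifier G_m, the sample-weight update, and the final classifier,

\alpha_m = \frac{1}{2} \ln \frac{1 - e_m}{e_m}, \qquad
w_{m+1,i} = \frac{w_{m,i}}{Z_m} \exp\bigl(-\alpha_m \, y_i \, G_m(x_i)\bigr), \qquad
G(x) = \operatorname{sign}\Bigl(\sum_{m=1}^{M} \alpha_m G_m(x)\Bigr),

where e_m is the weighted training error of G_m and Z_m normalizes the weights so they sum to 1. In the code, alpha, D, and aggClassEst play the roles of \alpha_m, w_m, and the running sum \sum_m \alpha_m G_m(x).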
from numpy import *
import matplotlib.pyplot as plt

def loadSimpData():
    # A tiny 2-D toy data set with labels in {+1, -1}.
    dataMat = matrix([[1, 2.1],
                      [2, 1.1],
                      [1.3, 1],
                      [1, 1],
                      [2, 1]])
    classLabels = [1.0, 1.0, -1.0, -1.0, 1.0]
    return dataMat, classLabels

def stumpClassify(dataMatrix, dimen, threshVal, threshIneq):
    # A decision stump: classify every sample by thresholding a single feature.
    # 'lt' assigns -1 to samples at or below the threshold, 'gt' to those above it.
    retArray = ones((shape(dataMatrix)[0], 1))
    if threshIneq == 'lt':
        retArray[dataMatrix[:, dimen] <= threshVal] = -1.0
    else:
        retArray[dataMatrix[:, dimen] > threshVal] = -1.0
    return retArray

def buildStump(dataArr, classLabels, D):
    # Find the stump (feature, threshold, direction) with the lowest
    # weighted error under the current sample weights D.
    dataMatrix = mat(dataArr); labelMat = mat(classLabels).T
    m, n = shape(dataMatrix)
    numSteps = 10.0; bestStump = {}; bestClasEst = mat(zeros((m, 1)))
    minError = inf
    for i in range(n):
        rangeMin = dataMatrix[:, i].min(); rangeMax = dataMatrix[:, i].max()
        stepSize = (rangeMax - rangeMin) / numSteps
        for j in range(-1, int(numSteps) + 1):
            for inequal in ['lt', 'gt']:
                threshVal = rangeMin + float(j) * stepSize
                predictedVals = stumpClassify(dataMatrix, i, threshVal, inequal)
                errArr = mat(ones((m, 1)))
                errArr[predictedVals == labelMat] = 0
                # D.T * errArr is a 1x1 matrix; take the scalar value.
                weightedError = (D.T * errArr)[0, 0]
                #print("split: dim %d, thresh %.2f, thresh inequal: %s, the weighted error is %.3f" % (i, threshVal, inequal, weightedError))
                if weightedError < minError:
                    minError = weightedError
                    bestClasEst = predictedVals.copy()
                    bestStump['dim'] = i
                    bestStump['thresh'] = threshVal
                    bestStump['ineq'] = inequal
    return bestStump, minError, bestClasEst

def adaBoostTrainDS(dataArr, classLabels, numIt=40):
    # AdaBoost training with decision stumps ("DS") as weak classifiers.
    weakClassArr = []
    m = shape(dataArr)[0]
    D = mat(ones((m, 1)) / m)           # start with uniform sample weights
    aggClassEst = mat(zeros((m, 1)))    # running weighted vote of all stumps
    for i in range(numIt):
        bestStump, error, classEst = buildStump(dataArr, classLabels, D)
        print("D:", D.T)
        # Classifier weight alpha = 0.5 * ln((1 - e) / e); 1e-16 avoids division by zero.
        alpha = float(0.5 * log((1.0 - error) / max(error, 1e-16)))
        bestStump['alpha'] = alpha
        weakClassArr.append(bestStump)
        print("classEst:", classEst)
        # Increase the weights of misclassified samples, decrease those of correct ones.
        expon = multiply(-1 * alpha * mat(classLabels).T, classEst)
        D = multiply(D, exp(expon))
        D = D / D.sum()
        aggClassEst += alpha * classEst
        print("aggClassEst:", aggClassEst.T)
        # Training error of the current ensemble; stop early once it reaches zero.
        aggErrors = multiply(sign(aggClassEst) != mat(classLabels).T, ones((m, 1)))
        errorRate = aggErrors.sum() / m
        print("total error:", errorRate)
        if errorRate == 0.0:
            break
    return weakClassArr

dataMat, classLabels = loadSimpData()
classifierArray = adaBoostTrainDS(dataMat, classLabels, 9)
print(classifierArray)
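The listing above only trains the ensemble (it returns weakClassArr); it does not include a prediction step. A minimal sketch of how the trained stumps could be applied to new samples, reusing stumpClassify from above (the function name adaClassify and the two test points are my own additions, not part of the original listing):

def adaClassify(datToClass, classifierArr):
    # Aggregate the weighted votes of every trained stump and return the sign (+1/-1).
    dataMatrix = mat(datToClass)
    m = shape(dataMatrix)[0]
    aggClassEst = mat(zeros((m, 1)))
    for stump in classifierArr:
        classEst = stumpClassify(dataMatrix, stump['dim'],
                                 stump['thresh'], stump['ineq'])
        aggClassEst += stump['alpha'] * classEst
    return sign(aggClassEst)

# Example with two hypothetical test points, using the ensemble trained above.
print(adaClassify([[5, 5], [0, 0]], classifierArray))

Each stump votes with weight alpha, and the sign of the aggregated score is the predicted label, mirroring the aggClassEst logic inside adaBoostTrainDS.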