由于对AdaBoost算法的弱分类器不是很了解，没明白算法描述里“在权值分布的训练集上，取阈值使得分类误差率最小，然后就得到基本分类器”这句话，也就是不明白怎样根据权值分布得到阈值。下面提供的代码直接给出了弱分类器，不知道这样做对不对，有问题请指出，谢谢，一起学习。由于被弱分类器搞得很郁闷，代码中没有添加注释，但步骤1、2、3（即方法first、second、third）是按照李航书中算法描述的步骤1、2、3编写的。
1 import java.util.ArrayList; 2 import java.util.List; 3 4 public class AdaBoost { 5 public static void main(String[] args){ 6 TestPoint[] testpoint = new TestPoint[10]; 7 testpoint[0] = new TestPoint(0,1); 8 testpoint[1] = new TestPoint(1,1); 9 testpoint[2] = new TestPoint(2,1); 10 testpoint[3] = new TestPoint(3,-1); 11 testpoint[4] = new TestPoint(4,-1); 12 testpoint[5] = new TestPoint(5,-1); 13 testpoint[6] = new TestPoint(6,1); 14 testpoint[7] = new TestPoint(7,1); 15 testpoint[8] = new TestPoint(8,1); 16 testpoint[9] = new TestPoint(9,-1); 17 18 List<List<Integer>> G = new ArrayList<List<Integer>>(); 19 double[] v = {2.5,8.5,5.5}; 20 double[] D = new double[testpoint.length]; 21 List<Double> A = new ArrayList<Double>(); 22 23 D = first(D,testpoint); 24 for(int i=0;i<v.length;i++){ 25 second(testpoint,D,v[i],G,A,i); 26 } 27 third(A,G); 28 } 29 30 private static void third(List<Double> a, List<List<Integer>> g) { 31 System.out.print("所得函数: sign["); 32 for(int i=0;i<a.size();i++){ 33 System.out.print(a.get(i) + " * " + "g[" + i + "]"); 34 if(i<a.size()-1){ 35 System.out.print(" + "); 36 } 37 } 38 System.out.println("]"); 39 } 40 41 private static List<List<Integer>> second(TestPoint[] testpoint, double[] D, double v, List<List<Integer>> G, List<Double> A,int index) { 42 double Z = 0; 43 double error = 0.0; 44 double a = 0; 45 46 int[] GTemp = new int[testpoint.length]; 47 List<Integer> LTemp = new ArrayList<Integer>(); 48 49 for(int i=0;i<testpoint.length;i++){ 50 if(v != 5.5){ 51 if(testpoint[i].getX() < v){ 52 GTemp[i] = 1; 53 } 54 else{ 55 GTemp[i] = -1; 56 } 57 } 58 else{ 59 if(testpoint[i].getX() < v){ 60 GTemp[i] = -1; 61 } 62 else{ 63 GTemp[i] = 1; 64 } 65 } 66 } 67 68 for(int i=0;i<GTemp.length;i++){ 69 LTemp.add(GTemp[i]); 70 } 71 G.add(LTemp); 72 73 for(int i=0;i<testpoint.length;i++){ 74 if(testpoint[i].getY() != GTemp[i]){ 75 error += D[i] * 1; 76 } 77 } 78 79 System.out.println(" 错误率e" + (index + 1) + ":" +error); 80 81 a = 0.5 * 
Math.log((1-error)/error); 82 A.add(a); 83 84 for(int i=0;i<testpoint.length;i++){ 85 Z += D[i] * Math.exp((-a) * testpoint[i].getY() * GTemp[i]); 86 } 87 88 System.out.print("权值分布D" + (index +1) + ":" ); 89 for(int i=0;i<testpoint.length;i++){ 90 D[i] = (D[i]/Z) * Math.exp((-a) * testpoint[i].getY() * GTemp[i]); 91 System.out.print(D[i] + " "); 92 } 93 System.out.println(); 94 95 return G; 96 } 97 98 private static double[] first(double[] D, TestPoint[] testpoint) { 99 for(int i=0;i<testpoint.length;i++){ 100 D[i] = 1.0/testpoint.length; 101 } 102 return D; 103 } 104 }
/**
 * A training data point: {@code x} is the data value and {@code y} is its class label.
 * In AdaBoost.main the labels are +1 or -1.
 */
public class TestPoint {
    private double x;
    private double y;

    public TestPoint(double x, double y) {
        this.x = x;
        this.y = y;
    }

    public double getX() {
        return x;
    }

    public double getY() {
        return y;
    }

    /**
     * Sets the data value. The parameter is {@code double} to match the field; {@code int}
     * arguments still compile through widening.
     */
    public void setX(double x) {
        this.x = x;
    }

    /** Sets the class label. Takes {@code double} to match the field. */
    public void setY(double y) {
        this.y = y;
    }

    /** Returns this point as {@code "x y"}. */
    @Override
    public String toString() {
        return x + " " + y;
    }

    /**
     * Formats the given pair as {@code "x y"}.
     *
     * <p>Kept for backward compatibility. It ignores this point's own fields; use
     * {@link #toString()} to format the point itself.
     */
    public String toString(double x, double y) {
        return x + " " + y;
    }
}
运行结果:
错误率e1:0.30000000000000004
权值分布D1:0.07142857142857142 0.07142857142857142 0.07142857142857142 0.07142857142857142 0.07142857142857142 0.07142857142857142 0.16666666666666663 0.16666666666666663 0.16666666666666663 0.07142857142857142
错误率e2:0.21428571428571427
权值分布D2:0.04545454545454546 0.04545454545454546 0.04545454545454546 0.1666666666666667 0.1666666666666667 0.1666666666666667 0.10606060606060605 0.10606060606060605 0.10606060606060605 0.04545454545454546
错误率e3:0.18181818181818185
权值分布D3:0.12499999999999997 0.12499999999999997 0.12499999999999997 0.10185185185185185 0.10185185185185185 0.10185185185185185 0.0648148148148148 0.0648148148148148 0.0648148148148148 0.12499999999999997
所得函数: sign[0.4236489301936017 * g[0] + 0.6496414920651304 * g[1] + 0.752038698388137 * g[2]]