zoukankan      html  css  js  c++  java
  • java AdaBoost算法

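    由于对AdaBoost算法的弱分类器不是很了解,没弄明白算法描述里“在权值分布的训练集上,取阈值使得分类误差率最小,然后就得到基本分类器”这句话,即不太明白如何根据权值分布得到阈值。(补充:常见做法是枚举所有候选阈值——例如相邻两个 x 的中点 0.5, 1.5, …, 8.5——以及“小于阈值判为 +1”和“小于阈值判为 −1”两个方向,用当前权值分布 D 计算每个候选的加权误差率 e = Σ D_i·I(G(x_i) ≠ y_i),取 e 最小的那个作为本轮的基本分类器。)提供的代码是直接给出了弱分类器,不知道这样是否正确,有问题请指出,谢谢,一起学习。由于被弱分类器搞得很郁闷,所以代码中没有添加注释,但步骤1、2、3是按照李航书中算法描述的第1、2、3步编写的。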

      import java.util.ArrayList;
      import java.util.Arrays;
      import java.util.List;

      4 public class AdaBoost {
      5     public static void main(String[] args){
      6         TestPoint[] testpoint = new TestPoint[10];
      7         testpoint[0] = new TestPoint(0,1);
      8         testpoint[1] = new TestPoint(1,1);
      9         testpoint[2] = new TestPoint(2,1);
     10         testpoint[3] = new TestPoint(3,-1);
     11         testpoint[4] = new TestPoint(4,-1);
     12         testpoint[5] = new TestPoint(5,-1);
     13         testpoint[6] = new TestPoint(6,1);
     14         testpoint[7] = new TestPoint(7,1);
     15         testpoint[8] = new TestPoint(8,1);
     16         testpoint[9] = new TestPoint(9,-1);
     17         
     18         List<List<Integer>> G = new ArrayList<List<Integer>>();
     19         double[] v = {2.5,8.5,5.5};
     20         double[] D = new double[testpoint.length];
     21         List<Double> A = new ArrayList<Double>();
     22         
     23         D = first(D,testpoint);
     24         for(int i=0;i<v.length;i++){
     25             second(testpoint,D,v[i],G,A,i);
     26         }
     27         third(A,G);
     28     }
     29 
     30     private static void third(List<Double> a, List<List<Integer>> g) {
     31         System.out.print("所得函数:    sign[");
     32         for(int i=0;i<a.size();i++){
     33             System.out.print(a.get(i) + " * " + "g[" + i + "]");
     34             if(i<a.size()-1){
     35                 System.out.print(" + ");
     36             }
     37         }
     38         System.out.println("]");
     39     }
     40 
     41     private static List<List<Integer>> second(TestPoint[] testpoint, double[] D, double v, List<List<Integer>> G, List<Double> A,int index) {
     42         double Z = 0;
     43         double error = 0.0;
     44         double a = 0;
     45         
     46         int[] GTemp = new int[testpoint.length];
     47         List<Integer> LTemp = new ArrayList<Integer>();
     48         
     49         for(int i=0;i<testpoint.length;i++){
     50             if(v != 5.5){    
     51                 if(testpoint[i].getX() < v){
     52                     GTemp[i] = 1;
     53                 }
     54                 else{
     55                     GTemp[i] = -1;
     56                 }
     57             }
     58             else{
     59                 if(testpoint[i].getX() < v){
     60                     GTemp[i] = -1;
     61                 }
     62                 else{
     63                     GTemp[i] = 1;
     64                 }
     65             }
     66         }
     67 
     68         for(int i=0;i<GTemp.length;i++){
     69             LTemp.add(GTemp[i]);
     70         }
     71         G.add(LTemp);
     72         
     73         for(int i=0;i<testpoint.length;i++){
     74             if(testpoint[i].getY() != GTemp[i]){
     75                 error += D[i] * 1;
     76             }
     77         }
     78         
     79         System.out.println("         错误率e" + (index + 1) + ":" +error);
     80         
     81         a = 0.5 * Math.log((1-error)/error);
     82         A.add(a);
     83         
     84         for(int i=0;i<testpoint.length;i++){
     85             Z += D[i] * Math.exp((-a) * testpoint[i].getY() * GTemp[i]);
     86         }
     87         
     88         System.out.print("权值分布D" + (index +1) + ":" ); 
     89         for(int i=0;i<testpoint.length;i++){
     90             D[i] = (D[i]/Z) * Math.exp((-a) * testpoint[i].getY() * GTemp[i]);
     91             System.out.print(D[i] + "  ");
     92         }
     93         System.out.println();
     94         
     95         return G;
     96     }
     97     
     98     private static double[] first(double[] D, TestPoint[] testpoint) {
     99         for(int i=0;i<testpoint.length;i++){
    100             D[i] = 1.0/testpoint.length;
    101         }
    102         return D;
    103     }
    104 }
     // A training sample: x is the feature value and y is its class label.
     public class TestPoint {
         private double x;
         private double y;

         /**
          * @param x feature value
          * @param y class label
          */
         public TestPoint(double x, double y) {
             this.x = x;
             this.y = y;
         }

         /** @return the feature value */
         public double getX() {
             return x;
         }

         /** @return the class label */
         public double getY() {
             return y;
         }

         /**
          * Sets the feature value. Takes a double to match the field type; int arguments still
          * compile via widening.
          */
         public void setX(double x) {
             this.x = x;
         }

         /**
          * Sets the class label. Takes a double to match the field type; int arguments still
          * compile via widening.
          */
         public void setY(double y) {
             this.y = y;
         }

         /** @return this point formatted as "x y" */
         @Override
         public String toString() {
             return x + " " + y;
         }

         /**
          * Formats the two given values as "x y". This ignores the point's own fields and is kept
          * only for backward compatibility; prefer {@link #toString()}.
          */
         public String toString(double x, double y) {
             return x + " " + y;
         }
     }

    运行结果:

    错误率e1:0.30000000000000004
    权值分布D1:0.07142857142857142 0.07142857142857142 0.07142857142857142 0.07142857142857142 0.07142857142857142 0.07142857142857142 0.16666666666666663 0.16666666666666663 0.16666666666666663 0.07142857142857142
    错误率e2:0.21428571428571427
    权值分布D2:0.04545454545454546 0.04545454545454546 0.04545454545454546 0.1666666666666667 0.1666666666666667 0.1666666666666667 0.10606060606060605 0.10606060606060605 0.10606060606060605 0.04545454545454546
    错误率e3:0.18181818181818185
    权值分布D3:0.12499999999999997 0.12499999999999997 0.12499999999999997 0.10185185185185185 0.10185185185185185 0.10185185185185185 0.0648148148148148 0.0648148148148148 0.0648148148148148 0.12499999999999997
    所得函数: sign[0.4236489301936017 * g[0] + 0.6496414920651304 * g[1] + 0.752038698388137 * g[2]]

  • 相关阅读:
    Wireshark教程
    存储基础知识4——
    java核心技术 要点笔记3
    java核心技术 要点笔记2
    java核心技术 要点笔记1
    java Vamei快速教程22 内存管理和垃圾回收
    java Vamei快速教程21 事件响应
    php框架练习
    php之图片处理类缩略图加水印
    php之框架增加日志记录功能类
  • 原文地址:https://www.cnblogs.com/wn19910213/p/3336294.html
Copyright © 2011-2022 走看看