zoukankan      html  css  js  c++  java
  • AdaBoost装袋提升算法

    参开资料:http://blog.csdn.net/haidao2009/article/details/7514787
    更多挖掘算法:https://github.com/linyiqun/DataMiningAlgorithm

    介绍

    在介绍AdaBoost算法之前,需要了解一个类似的算法,装袋算法(bagging),bagging是一种提高分类准确率的算法,通过给定组合投票的方式,获得最优解。比如你生病了,去n个医院看了n个医生,每个医生给你开了药方,最后的结果中,哪个药方的出现的次数多,那就说明这个药方就越有可能性是最由解,这个很好理解。而bagging算法就是这个思想。

    算法原理

    而AdaBoost算法的核心思想还是基于bagging算法,但是他又一点点的改进,上面的每个医生的投票结果都是一样的,说明地位平等,如果在这里加上一个权重,大城市的医生权重高点,小县城的医生权重低,这样通过最终计算权重和的方式,会更加的合理,这就是AdaBoost算法。AdaBoost算法是一种迭代算法,只有最终分类误差率小于阈值算法才能停止,针对同一训练集数据训练不同的分类器,我们称弱分类器,最后按照权重和的形式组合起来,构成一个组合分类器,就是一个强分类器了。算法的只要过程:

    1、对D训练集数据训练处一个分类器Ci

    2、通过分类器Ci对数据进行分类,计算此时误差率

    3、把上步骤中的分错的数据的权重提高,分对的权重降低,以此凸显了分错的数据。为什么这么做呢,后面会做出解释。

    完整的adaboost算法如下


    最后的sign函数是符号函数,如果最后的值为正,则分为+1类,否则即使-1类。

    我们举个例子代入上面的过程,这样能够更好的理解。

    adaboost的实现过程:

      图中,“+”和“-”分别表示两种类别,在这个过程中,我们使用水平或者垂直的直线作为分类器,来进行分类。

      第一步:

      根据分类的正确率,得到一个新的样本分布D,一个子分类器h1

      其中划圈的样本表示被分错的。在右边的途中,比较大的“+”表示对该样本做了加权。

    算法最开始给了一个均匀分布 D 。所以h1 里的每个点的值是0.1。ok,当划分后,有三个点划分错了,根据算法误差表达式得到 误差为分错了的三个点的值之和,所以ɛ1=(0.1+0.1+0.1)=0.3,而ɑ1 根据表达式 的可以算出来为0.42. 然后就根据算法 把分错的点权值变大。如此迭代,最终完成adaboost算法。

      第二步:

      根据分类的正确率,得到一个新的样本分布D3,一个子分类器h2

      第三步:

      得到一个子分类器h3

      整合所有子分类器:

      因此可以得到整合的结果,从结果中看,及时简单的分类器,组合起来也能获得很好的分类效果,在例子中所有的。后面的代码实现时,举出的也是这个例子,可以做对比,这里有一点比较重要,就是点的权重经过大小变化之后,需要进行归一化,确保总和为1.0,这个容易遗忘。

    算法的代码实现

    输入测试数据,与上图的例子相对应(数据格式:x坐标 y坐标 已分类结果):

    1 5 1
    2 3 1
    3 1 -1
    4 5 -1
    5 6 1
    6 4 -1
    6 7 1
    7 6 1
    8 7 -1
    8 2 -1
    

    Point.java

    package DataMining_AdaBoost;
    
    /**
     * 坐标点类
     * 
     * @author lyq
     * 
     */
    public class Point {
    	// 坐标点x坐标
    	private int x;
    	// 坐标点y坐标
    	private int y;
    	// 坐标点的分类类别
    	private int classType;
    	//如果此节点被划错,他的误差率,不能用个数除以总数,因为不同坐标点的权重不一定相等
    	private double probably;
    	
    	public Point(int x, int y, int classType){
    		this.x = x;
    		this.y = y;
    		this.classType = classType;
    	}
    	
    	public Point(String x, String y, String classType){
    		this.x = Integer.parseInt(x);
    		this.y = Integer.parseInt(y);
    		this.classType = Integer.parseInt(classType);
    	}
    
    	public int getX() {
    		return x;
    	}
    
    	public void setX(int x) {
    		this.x = x;
    	}
    
    	public int getY() {
    		return y;
    	}
    
    	public void setY(int y) {
    		this.y = y;
    	}
    
    	public int getClassType() {
    		return classType;
    	}
    
    	public void setClassType(int classType) {
    		this.classType = classType;
    	}
    
    	public double getProbably() {
    		return probably;
    	}
    
    	public void setProbably(double probably) {
    		this.probably = probably;
    	}
    }
    
    AdaBoost.java

    package DataMining_AdaBoost;
    
    import java.io.BufferedReader;
    import java.io.File;
    import java.io.FileReader;
    import java.io.IOException;
    import java.text.MessageFormat;
    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.Map;
    
    /**
     * AdaBoost提升算法工具类
     * 
     * @author lyq
     * 
     */
    public class AdaBoostTool {
    	// 分类的类别,程序默认为正类1和负类-1
    	public static final int CLASS_POSITIVE = 1;
    	public static final int CLASS_NEGTIVE = -1;
    
    	// 事先假设的3个分类器(理论上应该重新对数据集进行训练得到)
    	public static final String CLASSIFICATION1 = "X=2.5";
    	public static final String CLASSIFICATION2 = "X=7.5";
    	public static final String CLASSIFICATION3 = "Y=5.5";
    
    	// 分类器组
    	public static final String[] ClASSIFICATION = new String[] {
    			CLASSIFICATION1, CLASSIFICATION2, CLASSIFICATION3 };
    	// 分类权重组
    	private double[] CLASSIFICATION_WEIGHT;
    
    	// 测试数据文件地址
    	private String filePath;
    	// 误差率阈值
    	private double errorValue;
    	// 所有的数据点
    	private ArrayList<Point> totalPoint;
    
    	public AdaBoostTool(String filePath, double errorValue) {
    		this.filePath = filePath;
    		this.errorValue = errorValue;
    		readDataFile();
    	}
    
    	/**
    	 * 从文件中读取数据
    	 */
    	private void readDataFile() {
    		File file = new File(filePath);
    		ArrayList<String[]> dataArray = new ArrayList<String[]>();
    
    		try {
    			BufferedReader in = new BufferedReader(new FileReader(file));
    			String str;
    			String[] tempArray;
    			while ((str = in.readLine()) != null) {
    				tempArray = str.split(" ");
    				dataArray.add(tempArray);
    			}
    			in.close();
    		} catch (IOException e) {
    			e.getStackTrace();
    		}
    
    		Point temp;
    		totalPoint = new ArrayList<>();
    		for (String[] array : dataArray) {
    			temp = new Point(array[0], array[1], array[2]);
    			temp.setProbably(1.0 / dataArray.size());
    			totalPoint.add(temp);
    		}
    	}
    
    	/**
    	 * 根据当前的误差值算出所得的权重
    	 * 
    	 * @param errorValue
    	 *            当前划分的坐标点误差率
    	 * @return
    	 */
    	private double calculateWeight(double errorValue) {
    		double alpha = 0;
    		double temp = 0;
    
    		temp = (1 - errorValue) / errorValue;
    		alpha = 0.5 * Math.log(temp);
    
    		return alpha;
    	}
    
    	/**
    	 * 计算当前划分的误差率
    	 * 
    	 * @param pointMap
    	 *            划分之后的点集
    	 * @param weight
    	 *            本次划分得到的分类器权重
    	 * @return
    	 */
    	private double calculateErrorValue(
    			HashMap<Integer, ArrayList<Point>> pointMap) {
    		double resultValue = 0;
    		double temp = 0;
    		double weight = 0;
    		int tempClassType;
    		ArrayList<Point> pList;
    		for (Map.Entry entry : pointMap.entrySet()) {
    			tempClassType = (int) entry.getKey();
    
    			pList = (ArrayList<Point>) entry.getValue();
    			for (Point p : pList) {
    				temp = p.getProbably();
    				// 如果划分类型不相等,代表划错了
    				if (tempClassType != p.getClassType()) {
    					resultValue += temp;
    				}
    			}
    		}
    
    		weight = calculateWeight(resultValue);
    		for (Map.Entry entry : pointMap.entrySet()) {
    			tempClassType = (int) entry.getKey();
    
    			pList = (ArrayList<Point>) entry.getValue();
    			for (Point p : pList) {
    				temp = p.getProbably();
    				// 如果划分类型不相等,代表划错了
    				if (tempClassType != p.getClassType()) {
    					// 划错的点的权重比例变大
    					temp *= Math.exp(weight);
    					p.setProbably(temp);
    				} else {
    					// 划对的点的权重比减小
    					temp *= Math.exp(-weight);
    					p.setProbably(temp);
    				}
    			}
    		}
    
    		// 如果误差率没有小于阈值,继续处理
    		dataNormalized();
    
    		return resultValue;
    	}
    
    	/**
    	 * 概率做归一化处理
    	 */
    	private void dataNormalized() {
    		double sumProbably = 0;
    		double temp = 0;
    
    		for (Point p : totalPoint) {
    			sumProbably += p.getProbably();
    		}
    
    		// 归一化处理
    		for (Point p : totalPoint) {
    			temp = p.getProbably();
    			p.setProbably(temp / sumProbably);
    		}
    	}
    
    	/**
    	 * 用AdaBoost算法得到的组合分类器对数据进行分类
    	 * 
    	 */
    	public void adaBoostClassify() {
    		double value = 0;
    		Point p;
    
    		calculateWeightArray();
    		for (int i = 0; i < ClASSIFICATION.length; i++) {
    			System.out.println(MessageFormat.format("分类器{0}权重为:{1}", (i+1), CLASSIFICATION_WEIGHT[i]));
    		}
    		
    		for (int j = 0; j < totalPoint.size(); j++) {
    			p = totalPoint.get(j);
    			value = 0;
    
    			for (int i = 0; i < ClASSIFICATION.length; i++) {
    				value += 1.0 * classifyData(ClASSIFICATION[i], p)
    						* CLASSIFICATION_WEIGHT[i];
    			}
    			
    			//进行符号判断
    			if (value > 0) {
    				System.out
    						.println(MessageFormat.format(
    								"点({0}, {1})的组合分类结果为:1,该点的实际分类为{2}", p.getX(), p.getY(),
    								p.getClassType()));
    			} else {
    				System.out.println(MessageFormat.format(
    						"点({0}, {1})的组合分类结果为:-1,该点的实际分类为{2}", p.getX(), p.getY(),
    						p.getClassType()));
    			}
    		}
    	}
    
    	/**
    	 * 计算分类器权重数组
    	 */
    	private void calculateWeightArray() {
    		int tempClassType = 0;
    		double errorValue = 0;
    		ArrayList<Point> posPointList;
    		ArrayList<Point> negPointList;
    		HashMap<Integer, ArrayList<Point>> mapList;
    		CLASSIFICATION_WEIGHT = new double[ClASSIFICATION.length];
    
    		for (int i = 0; i < CLASSIFICATION_WEIGHT.length; i++) {
    			mapList = new HashMap<>();
    			posPointList = new ArrayList<>();
    			negPointList = new ArrayList<>();
    
    			for (Point p : totalPoint) {
    				tempClassType = classifyData(ClASSIFICATION[i], p);
    
    				if (tempClassType == CLASS_POSITIVE) {
    					posPointList.add(p);
    				} else {
    					negPointList.add(p);
    				}
    			}
    
    			mapList.put(CLASS_POSITIVE, posPointList);
    			mapList.put(CLASS_NEGTIVE, negPointList);
    
    			if (i == 0) {
    				// 最开始的各个点的权重一样,所以传入0,使得e的0次方等于1
    				errorValue = calculateErrorValue(mapList);
    			} else {
    				// 每次把上次计算所得的权重代入,进行概率的扩大或缩小
    				errorValue = calculateErrorValue(mapList);
    			}
    
    			// 计算当前分类器的所得权重
    			CLASSIFICATION_WEIGHT[i] = calculateWeight(errorValue);
    		}
    	}
    
    	/**
    	 * 用各个子分类器进行分类
    	 * 
    	 * @param classification
    	 *            分类器名称
    	 * @param p
    	 *            待划分坐标点
    	 * @return
    	 */
    	private int classifyData(String classification, Point p) {
    		// 分割线所属坐标轴
    		String position;
    		// 分割线的值
    		double value = 0;
    		double posProbably = 0;
    		double negProbably = 0;
    		// 划分是否是大于一边的划分
    		boolean isLarger = false;
    		String[] array;
    		ArrayList<Point> pList = new ArrayList<>();
    
    		array = classification.split("=");
    		position = array[0];
    		value = Double.parseDouble(array[1]);
    
    		if (position.equals("X")) {
    			if (p.getX() > value) {
    				isLarger = true;
    			}
    
    			// 将训练数据中所有属于这边的点加入
    			for (Point point : totalPoint) {
    				if (isLarger && point.getX() > value) {
    					pList.add(point);
    				} else if (!isLarger && point.getX() < value) {
    					pList.add(point);
    				}
    			}
    		} else if (position.equals("Y")) {
    			if (p.getY() > value) {
    				isLarger = true;
    			}
    
    			// 将训练数据中所有属于这边的点加入
    			for (Point point : totalPoint) {
    				if (isLarger && point.getY() > value) {
    					pList.add(point);
    				} else if (!isLarger && point.getY() < value) {
    					pList.add(point);
    				}
    			}
    		}
    
    		for (Point p2 : pList) {
    			if (p2.getClassType() == CLASS_POSITIVE) {
    				posProbably++;
    			} else {
    				negProbably++;
    			}
    		}
    		
    		//分类按正负类数量进行划分
    		if (posProbably > negProbably) {
    			return CLASS_POSITIVE;
    		} else {
    			return CLASS_NEGTIVE;
    		}
    	}
    
    }
    
    调用类Client.java:

    /**
     * AdaBoost提升算法调用类
     * @author lyq
     *
     */
    public class Client {
    	public static void main(String[] agrs){
    		String filePath = "C:\Users\lyq\Desktop\icon\input.txt";
    		//误差率阈值
    		double errorValue = 0.2;
    		
    		AdaBoostTool tool = new AdaBoostTool(filePath, errorValue);
    		tool.adaBoostClassify();
    	}
    }
    输出结果:

    分类器1权重为:0.424
    分类器2权重为:0.65
    分类器3权重为:0.923
    点(1, 5)的组合分类结果为:1,该点的实际分类为1
    点(2, 3)的组合分类结果为:1,该点的实际分类为1
    点(3, 1)的组合分类结果为:-1,该点的实际分类为-1
    点(4, 5)的组合分类结果为:-1,该点的实际分类为-1
    点(5, 6)的组合分类结果为:1,该点的实际分类为1
    点(6, 4)的组合分类结果为:-1,该点的实际分类为-1
    点(6, 7)的组合分类结果为:1,该点的实际分类为1
    点(7, 6)的组合分类结果为:1,该点的实际分类为1
    点(8, 7)的组合分类结果为:-1,该点的实际分类为-1
    点(8, 2)的组合分类结果为:-1,该点的实际分类为-1

    我们可以看到,如果3个分类单独分类,都没有百分百分对,而尽管组合结果之后,全部分类正确。

    我对AdaBoost算法的理解

    到了算法的末尾,有必要解释一下每次分类自后需要把错的点的权重增大,正确的减少的理由了,加入上次分类之后,(1,5)已经分错了,如果这次又分错,由于上次的权重已经提升,所以误差率更大,则代入公式ln(1-误差率/误差率)所得的权重越小,也就是说,如果同个数据,你分类的次数越多,你的权重越小,所以这就造成整体好的分类器的权重会越大,内部就会同时有各种权重的分类器,形成了一种互补的结果,如果好的分类器结果分错 ,可以由若干弱一点的分类器进行弥补。

    AdaBoost算法的应用

    可以运用在诸如特征识别,二分类的一些应用上,与单个模型相比,组合的形式能显著的提高准确率。


  • 相关阅读:
    centos7安装doxygen
    mysql和mariadb支持insert delayed的问题
    Ubuntu用android-ndk-r15c编译boost_1_65_1
    记不住的Android活动的生命周期
    SpringBoot——经典的Hello World【二】
    SpringBoot——报错总结
    SpringBoot——SpringBoot学习记录【一】
    Nginx——配置文件服务下载
    CRAP-API——如何在Linux服务器部署CRAP-API教程
    Linux—— 报错汇总
  • 原文地址:https://www.cnblogs.com/bianqi/p/12184008.html
Copyright © 2011-2022 走看看