zoukankan      html  css  js  c++  java
  • JAVA实现KNN分类

    转载请注明出处:http://blog.csdn.net/xiaojimanman/article/details/51064307

    http://www.llwjy.com/blogdetail/f74b497c2ad6261b0ea651454b97a390.html

    个人博客站已经上线了,网址 www.llwjy.com ~欢迎各位吐槽~

    -------------------------------------------------------------------------------------------------

          在開始之前先打一个小小的广告,自己创建一个QQ群:321903218,点击链接加入群【Lucene案例开发】,主要用于交流怎样使用Lucene来创建站内搜索后台,同一时候还会不定期的在群内开相关的公开课,感兴趣的童鞋能够加入交流。


          KNN算法又叫近邻算法,是数据挖掘中一种经常使用的分类算法,接单的介绍KNN算法的核心思想就是:寻找与目标近期的K个个体,这些样本属于类别最多的那个类别就是目标的类别。比方K为7,那么我们就从数据中找到和目标近期(或者类似度最高)的7个样本,加入这7个样本相应的类别分别为A、B、C、A、A、A、B,那么目标属于的分类就是A(由于这7个样本中属于A类别的样本个数最多)。


    算法实现

    一、训练数据格式定义

          以下就简单的介绍下怎样用JAVA来实现KNN分类,首先我们须要存储训练集(包含属性以及相应的类别),这里我们对未知的属性使用泛型。类别我们使用字符串存储。

     /**  
     *@Description:  KNN分类模型中一条记录的存储格式
     */ 
    package com.lulei.datamining.knn.bean;  
      
    public class KnnValueBean<T>{
    	private T value;//记录值
    	private String typeId;//分类ID
    	
    	public KnnValueBean(T value, String typeId) {
    		this.value = value;
    		this.typeId = typeId;
    	}
    
    	public T getValue() {
    		return value;
    	}
    
    	public void setValue(T value) {
    		this.value = value;
    	}
    
    	public String getTypeId() {
    		return typeId;
    	}
    
    	public void setTypeId(String typeId) {
    		this.typeId = typeId;
    	}
    }
    

    二、K个近期邻类别数据格式定义

          在统计得到K个近期邻中,我们须要记录前K个样本的分类以及相应的类似度,我们这里使用例如以下数据格式:

     /**  
     *@Description: K个近期邻的类别得分
     */ 
    package com.lulei.datamining.knn.bean;  
      
    public class KnnValueSort {
    	private String typeId;//分类ID
    	private double score;//该分类得分
    	
    	public KnnValueSort(String typeId, double score) {
    		this.typeId = typeId;
    		this.score = score;
    	}
    	public String getTypeId() {
    		return typeId;
    	}
    	public void setTypeId(String typeId) {
    		this.typeId = typeId;
    	}
    	public double getScore() {
    		return score;
    	}
    	public void setScore(double score) {
    		this.score = score;
    	}
    }
    

    三、KNN算法基本属性

          在KNN算法中,最重要的一个指标就是K的取值,因此我们在基类中须要设置一个属性K以及设置一个数组用于存储已知分类的数据。

    private List<KnnValueBean> dataArray;
    private int K = 3;

    四、加入已知分类数据

          在使用KNN分类之前,我们须要先向当中加入我们已知分类的数据。我们后面就是使用这些数据来预測未知数据的分类。

    /**
     * @param value
     * @param typeId
     * @Author:lulei  
     * @Description: 向模型中加入记录
     */
    public void addRecord(T value, String typeId) {
    	if (dataArray == null) {
    		dataArray = new ArrayList<KnnValueBean>();
    	}
    	dataArray.add(new KnnValueBean<T>(value, typeId));
    }

    五、两个样本之间的类似度(或者距离)

          在KNN算法中,最重要的一个方法就是怎样确定两个样本之间的类似度(或者距离)。由于这里我们使用的是泛型。并没有办法确定两个对象之间的类似度。一次这里我们把它设置为抽象方法,让子类来实现。这里我们方法定义为类似度,也就是返回值越大。两者越类似,之间的距离越短

    /**
     * @param o1
     * @param o2
     * @return
     * @Author:lulei  
     * @Description: o1 o2之间的类似度
     */
    public abstract double similarScore(T o1, T o2);

    六、获取近期的K个样本的分类

          KNN算法的核心思想就是找到近期的K个近邻,因此这一步也是整个算法的核心部分。

    这里我们使用数组来保存类似度最大的K个样本的分类和类似度,在计算的过程中通过循环遍历全部的样本,数组保存截至当前计算点最类似的K个样本相应的类别和类似度。详细实现例如以下:

    /**
     * @param value
     * @return
     * @Author:lulei  
     * @Description: 获取距离近期的K个分类
     */
    private KnnValueSort[] getKType(T value) {
    	int k = 0;
    	KnnValueSort[] topK = new KnnValueSort[K];
    	for (KnnValueBean<T> bean : dataArray) {
    		double score = similarScore(bean.getValue(), value);
    		if (k == 0) {
    			//数组中的记录个数为0是直接加入
    			topK[k] = new KnnValueSort(bean.getTypeId(), score);
    			k++;
    		} else {
    			if (!(k == K && score < topK[k -1].getScore())){
    				int i = 0;
    				//找到要插入的点
    				for (; i < k && score < topK[i].getScore(); i++);
    				int j = k - 1;
    				if (k < K) {
    					j = k;
    					k++;
    				}
    				for (; j > i; j--) {
    					topK[j] = topK[j - 1];
    				}
    				topK[i] = new KnnValueSort(bean.getTypeId(), score);
    			}
    		}
    	}
    	return topK;
    }

    七、统计K个样本出现次数最多的类别

          这一步就是一个简单的计数,统计K个样本中出现次数最多的分类,该分类就是我们要预測的目标数据的分类。

    /**
     * @param value
     * @return
     * @Author:lulei  
     * @Description: KNN分类推断value的类别
     */
    public String getTypeId(T value) {
    	KnnValueSort[] array = getKType(value);
    	HashMap<String, Integer> map = new HashMap<String, Integer>(K);
    	for (KnnValueSort bean : array) {
    		if (bean != null) {
    			if (map.containsKey(bean.getTypeId())) {
    				map.put(bean.getTypeId(), map.get(bean.getTypeId()) + 1);
    			} else {
    				map.put(bean.getTypeId(), 1);
    			}
    		}
    	}
    	String maxTypeId = null;
    	int maxCount = 0;
    	Iterator<Entry<String, Integer>> iter = map.entrySet().iterator();
    	while (iter.hasNext()) {
    		Entry<String, Integer> entry = iter.next();
    		if (maxCount < entry.getValue()) {
    			maxCount = entry.getValue();
    			maxTypeId = entry.getKey();
    		}
    	}
    	return maxTypeId;
    }

          到如今为止KNN分类的抽象基类已经编写完毕,在測试之前我们先多说几句,KNN分类是统计K个样本中出现次数最多的分类,这样的在有些情况下并非特别合理。比方K=5。前5个样本相应的分类分别为A、A、B、B、B。相应的类似度得分分别为10、9、2、2、1。假设使用上面的方法,那预測的分类就是B。可是看这些数据,预測的分类是A感觉更合理。基于这样的情况,自己对KNN算法提出例如以下优化(这里并不提供代码,仅仅提供简单的思路):在获取最类似K个样本和类似度后。能够对类似度和出现次数K做一种函数运算。比方加权。得到的函数值最大的分类就是目标的预測分类。

    基类源代码

     /**  
     *@Description: KNN分类
     */ 
    package com.lulei.datamining.knn;  
    
    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.Iterator;
    import java.util.List;
    import java.util.Map.Entry;
    
    import com.lulei.datamining.knn.bean.KnnValueBean;
    import com.lulei.datamining.knn.bean.KnnValueSort;
    import com.lulei.util.JsonUtil;
      
    @SuppressWarnings({"rawtypes"})
    public abstract class KnnClassification<T> {
    	private List<KnnValueBean> dataArray;
    	private int K = 3;
    	
    	public int getK() {
    		return K;
    	}
    	public void setK(int K) {
    		if (K < 1) {
    			throw new IllegalArgumentException("K must greater than 0");
    		}
    		this.K = K;
    	}
    
    	/**
    	 * @param value
    	 * @param typeId
    	 * @Author:lulei  
    	 * @Description: 向模型中加入记录
    	 */
    	public void addRecord(T value, String typeId) {
    		if (dataArray == null) {
    			dataArray = new ArrayList<KnnValueBean>();
    		}
    		dataArray.add(new KnnValueBean<T>(value, typeId));
    	}
    	
    	/**
    	 * @param value
    	 * @return
    	 * @Author:lulei  
    	 * @Description: KNN分类推断value的类别
    	 */
    	public String getTypeId(T value) {
    		KnnValueSort[] array = getKType(value);
    		System.out.println(JsonUtil.parseJson(array));
    		HashMap<String, Integer> map = new HashMap<String, Integer>(K);
    		for (KnnValueSort bean : array) {
    			if (bean != null) {
    				if (map.containsKey(bean.getTypeId())) {
    					map.put(bean.getTypeId(), map.get(bean.getTypeId()) + 1);
    				} else {
    					map.put(bean.getTypeId(), 1);
    				}
    			}
    		}
    		String maxTypeId = null;
    		int maxCount = 0;
    		Iterator<Entry<String, Integer>> iter = map.entrySet().iterator();
    		while (iter.hasNext()) {
    			Entry<String, Integer> entry = iter.next();
    			if (maxCount < entry.getValue()) {
    				maxCount = entry.getValue();
    				maxTypeId = entry.getKey();
    			}
    		}
    		return maxTypeId;
    	}
    	
    	/**
    	 * @param value
    	 * @return
    	 * @Author:lulei  
    	 * @Description: 获取距离近期的K个分类
    	 */
    	private KnnValueSort[] getKType(T value) {
    		int k = 0;
    		KnnValueSort[] topK = new KnnValueSort[K];
    		for (KnnValueBean<T> bean : dataArray) {
    			double score = similarScore(bean.getValue(), value);
    			if (k == 0) {
    				//数组中的记录个数为0是直接加入
    				topK[k] = new KnnValueSort(bean.getTypeId(), score);
    				k++;
    			} else {
    				if (!(k == K && score < topK[k -1].getScore())){
    					int i = 0;
    					//找到要插入的点
    					for (; i < k && score < topK[i].getScore(); i++);
    					int j = k - 1;
    					if (k < K) {
    						j = k;
    						k++;
    					}
    					for (; j > i; j--) {
    						topK[j] = topK[j - 1];
    					}
    					topK[i] = new KnnValueSort(bean.getTypeId(), score);
    				}
    			}
    		}
    		return topK;
    	}
    	
    	/**
    	 * @param o1
    	 * @param o2
    	 * @return
    	 * @Author:lulei  
    	 * @Description: o1 o2之间的类似度
    	 */
    	public abstract double similarScore(T o1, T o2);
    }
    

    详细子类实现

          对于上面介绍的都在KNN分类的抽象基类中,对于实际的问题我们须要继承基类并实现基类中的类似度抽象方法,这里我们做一个简单的实现。

     /**  
     *@Description:     
     */ 
    package com.lulei.datamining.knn.test;  
    
    import com.lulei.datamining.knn.KnnClassification;
    import com.lulei.util.JsonUtil;
      
    public class Test extends KnnClassification<Integer>{
    	
    	@Override
    	public double similarScore(Integer o1, Integer o2) {
    		return -1 * Math.abs(o1 - o2);
    	}
    	
    	/**  
    	 * @param args
    	 * @Author:lulei  
    	 * @Description:  
    	 */
    	public static void main(String[] args) {
    		Test test = new Test();
    		for (int i = 1; i < 10; i++) {
    			test.addRecord(i, i > 5 ?

    "0" : "1"); } System.out.println(JsonUtil.parseJson(test.getTypeId(0))); } }


          这里我们一共加入了1、2、3、4、5、6、7、8、9这9组数据,前5组的类别为1,后4组的类别为0。两个数据之间的类似度为两者之间的差值的绝对值的相反数,以下预測0应该属于的分类,这里K的默认值为3,因此近期的K个样本分别为1、2、3。相应的分类分别为"1"、"1"、"1",由于最后预測的分类为"1"。

    -------------------------------------------------------------------------------------------------
    小福利
    -------------------------------------------------------------------------------------------------
          个人在极客学院上《Lucene案例开发》课程已经上线了。欢迎大家吐槽~

    第一课:Lucene概述

    第二课:Lucene 经常使用功能介绍

    第三课:网络爬虫

    第四课:数据库连接池

    第五课:小说站点的採集

    第六课:小说站点数据库操作

    第七课:小说站点分布式爬虫的实现

    第八课:Lucene实时搜索


  • 相关阅读:
    Mysql热备份
    win10 上安装虚拟机
    SpringMVC AJAX向后台传递数组参数/实体集合
    解决eclipse中tomcat不加载web项目的问题
    Python 基础第九天
    Python 基础第8天(文件管理)
    Python 基础第七天
    Python 基础第六天
    Python 基础第五天
    Python 基础第四天
  • 原文地址:https://www.cnblogs.com/lytwajue/p/7227943.html
Copyright © 2011-2022 走看看