zoukankan      html  css  js  c++  java
  • K-最近邻算法

    介绍

    KNN算法全名为k-Nearest Neighbor,就是K最近邻的意思。KNN也是一种分类算法。但是与之前说的决策树分类算法相比,这个算法算是最简单的一个了。算法的主要过程为:

    1、给定一个训练集数据,每个训练集数据都是已经分好类的。
    2、设定一个初始的测试数据a,计算a到训练集所有数据的欧几里得距离,并排序。                       

    3、选出训练集中离a距离最近的K个训练集数据。

    4、比较k个训练集数据,选出里面出现最多的分类类型,此分类类型即为最终测试数据a的分类。

    下面百度百科上的一张简图:


    KNN算法实现

    首先测试数据需要2块,1个是训练集数据,就是已经分好类的数据,比如上图中的非绿色的点。还有一个是测试数据,就是上面的绿点,当然这里的测试数据不会是一个,而是一组。这里的数据与数据之间的距离用数据的特征向量做计算,特征向量可以是多维度的。通过计算特征向量与特征向量之间的欧几里得距离来推算相似度。定义训练集数据trainInput.txt:

    a 1 2 3 4 5 
    b 5 4 3 2 1 
    c 3 3 3 3 3 
    d -3 -3 -3 -3 -3 
    a 1 2 3 4 4 
    b 4 4 3 2 1 
    c 3 3 3 2 4 
    d 0 0 1 1 -2 
    
    待测试数据testInput,只有特征向量值:

    1 2 3 2 4 
    2 3 4 2 1 
    8 7 2 3 5 
    -3 -2 2 4 0 
    -4 -4 -4 -4 -4 
    1 2 3 4 4 
    4 4 3 2 1 
    3 3 3 2 4 
    0 0 1 1 -2 
    下面是主程序:

    package DataMing_KNN;
    
    import java.io.BufferedReader;
    import java.io.File;
    import java.io.FileReader;
    import java.io.IOException;
    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.Collection;
    import java.util.Collections;
    import java.util.Comparator;
    import java.util.HashMap;
    import java.util.Map;
    
    import org.apache.activemq.filter.ComparisonExpression;
    
    /**
     * k最近邻算法工具类
     * 
     * @author lyq
     * 
     */
    public class KNNTool {
    	// 为4个类别设置权重,默认权重比一致
    	public int[] classWeightArray = new int[] { 1, 1, 1, 1 };
    	// 测试数据地址
    	private String testDataPath;
    	// 训练集数据地址
    	private String trainDataPath;
    	// 分类的不同类型
    	private ArrayList<String> classTypes;
    	// 结果数据
    	private ArrayList<Sample> resultSamples;
    	// 训练集数据列表容器
    	private ArrayList<Sample> trainSamples;
    	// 训练集数据
    	private String[][] trainData;
    	// 测试集数据
    	private String[][] testData;
    
    	public KNNTool(String trainDataPath, String testDataPath) {
    		this.trainDataPath = trainDataPath;
    		this.testDataPath = testDataPath;
    		readDataFormFile();
    	}
    
    	/**
    	 * 从文件中阅读测试数和训练数据集
    	 */
    	private void readDataFormFile() {
    		ArrayList<String[]> tempArray;
    
    		tempArray = fileDataToArray(trainDataPath);
    		trainData = new String[tempArray.size()][];
    		tempArray.toArray(trainData);
    
    		classTypes = new ArrayList<>();
    		for (String[] s : tempArray) {
    			if (!classTypes.contains(s[0])) {
    				// 添加类型
    				classTypes.add(s[0]);
    			}
    		}
    
    		tempArray = fileDataToArray(testDataPath);
    		testData = new String[tempArray.size()][];
    		tempArray.toArray(testData);
    	}
    
    	/**
    	 * 将文件转为列表数据输出
    	 * 
    	 * @param filePath
    	 *            数据文件的内容
    	 */
    	private ArrayList<String[]> fileDataToArray(String filePath) {
    		File file = new File(filePath);
    		ArrayList<String[]> dataArray = new ArrayList<String[]>();
    
    		try {
    			BufferedReader in = new BufferedReader(new FileReader(file));
    			String str;
    			String[] tempArray;
    			while ((str = in.readLine()) != null) {
    				tempArray = str.split(" ");
    				dataArray.add(tempArray);
    			}
    			in.close();
    		} catch (IOException e) {
    			e.getStackTrace();
    		}
    
    		return dataArray;
    	}
    
    	/**
    	 * 计算样本特征向量的欧几里得距离
    	 * 
    	 * @param f1
    	 *            待比较样本1
    	 * @param f2
    	 *            待比较样本2
    	 * @return
    	 */
    	private int computeEuclideanDistance(Sample s1, Sample s2) {
    		String[] f1 = s1.getFeatures();
    		String[] f2 = s2.getFeatures();
    		// 欧几里得距离
    		int distance = 0;
    
    		for (int i = 0; i < f1.length; i++) {
    			int subF1 = Integer.parseInt(f1[i]);
    			int subF2 = Integer.parseInt(f2[i]);
    
    			distance += (subF1 - subF2) * (subF1 - subF2);
    		}
    
    		return distance;
    	}
    
    	/**
    	 * 计算K最近邻
    	 * @param k
    	 * 在多少的k范围内
    	 */
    	public void knnCompute(int k) {
    		String className = "";
    		String[] tempF = null;
    		Sample temp;
    		resultSamples = new ArrayList<>();
    		trainSamples = new ArrayList<>();
    		// 分类类别计数
    		HashMap<String, Integer> classCount;
    		// 类别权重比
    		HashMap<String, Integer> classWeight = new HashMap<>();
    		// 首先讲测试数据转化到结果数据中
    		for (String[] s : testData) {
    			temp = new Sample(s);
    			resultSamples.add(temp);
    		}
    
    		for (String[] s : trainData) {
    			className = s[0];
    			tempF = new String[s.length - 1];
    			System.arraycopy(s, 1, tempF, 0, s.length - 1);
    			temp = new Sample(className, tempF);
    			trainSamples.add(temp);
    		}
    
    		// 离样本最近排序的的训练集数据
    		ArrayList<Sample> kNNSample = new ArrayList<>();
    		// 计算训练数据集中离样本数据最近的K个训练集数据
    		for (Sample s : resultSamples) {
    			classCount = new HashMap<>();
    			int index = 0;
    			for (String type : classTypes) {
    				// 开始时计数为0
    				classCount.put(type, 0);
    				classWeight.put(type, classWeightArray[index++]);
    			}
    			for (Sample tS : trainSamples) {
    				int dis = computeEuclideanDistance(s, tS);
    				tS.setDistance(dis);
    			}
    
    			Collections.sort(trainSamples);
    			kNNSample.clear();
    			// 挑选出前k个数据作为分类标准
    			for (int i = 0; i < trainSamples.size(); i++) {
    				if (i < k) {
    					kNNSample.add(trainSamples.get(i));
    				} else {
    					break;
    				}
    			}
    			// 判定K个训练数据的多数的分类标准
    			for (Sample s1 : kNNSample) {
    				int num = classCount.get(s1.getClassName());
    				// 进行分类权重的叠加,默认类别权重平等,可自行改变,近的权重大,远的权重小
    				num += classWeight.get(s1.getClassName());
    				classCount.put(s1.getClassName(), num);
    			}
    
    			int maxCount = 0;
    			// 筛选出k个训练集数据中最多的一个分类
    			for (Map.Entry entry : classCount.entrySet()) {
    				if ((Integer) entry.getValue() > maxCount) {
    					maxCount = (Integer) entry.getValue();
    					s.setClassName((String) entry.getKey());
    				}
    			}
    
    			System.out.print("测试数据特征:");
    			for (String s1 : s.getFeatures()) {
    				System.out.print(s1 + " ");
    			}
    			System.out.println("分类:" + s.getClassName());
    		}
    	}
    }
    
    Sample样本数据类:

    package DataMing_KNN;
    
    /**
     * 样本数据类
     * 
     * @author lyq
     * 
     */
    public class Sample implements Comparable<Sample>{
    	// 样本数据的分类名称
    	private String className;
    	// 样本数据的特征向量
    	private String[] features;
    	//测试样本之间的间距值,以此做排序
    	private Integer distance;
    	
    	public Sample(String[] features){
    		this.features = features;
    	}
    	
    	public Sample(String className, String[] features){
    		this.className = className;
    		this.features = features;
    	}
    
    	public String getClassName() {
    		return className;
    	}
    
    	public void setClassName(String className) {
    		this.className = className;
    	}
    
    	public String[] getFeatures() {
    		return features;
    	}
    
    	public void setFeatures(String[] features) {
    		this.features = features;
    	}
    
    	public Integer getDistance() {
    		return distance;
    	}
    
    	public void setDistance(int distance) {
    		this.distance = distance;
    	}
    
    	@Override
    	public int compareTo(Sample o) {
    		// TODO Auto-generated method stub
    		return this.getDistance().compareTo(o.getDistance());
    	}
    	
    }
    
    测试场景类:

    /**
     * k最近邻算法场景类型
     * @author lyq
     *
     */
    public class Client {
    	public static void main(String[] args){
    		String trainDataPath = "C:\Users\lyq\Desktop\icon\trainInput.txt";
    		String testDataPath = "C:\Users\lyq\Desktop\icon\testinput.txt";
    		
    		KNNTool tool = new KNNTool(trainDataPath, testDataPath);
    		tool.knnCompute(3);
    		
    	}
    	
    
    
    }
    
    执行的结果为:

    测试数据特征:1 2 3 2 4 分类:a
    测试数据特征:2 3 4 2 1 分类:c
    测试数据特征:8 7 2 3 5 分类:b
    测试数据特征:-3 -2 2 4 0 分类:a
    测试数据特征:-4 -4 -4 -4 -4 分类:d
    测试数据特征:1 2 3 4 4 分类:a
    测试数据特征:4 4 3 2 1 分类:b
    测试数据特征:3 3 3 2 4 分类:c
    测试数据特征:0 0 1 1 -2 分类:d

    程序的输出结果如上所示,如果不相信的话可以自己动手计算进行验证。

    KNN算法的注意点:

    1、knn算法的训练集数据必须要相对公平,各个类型的数据数量应该是平均的,否则当A数据由1000个B数据由100个,到时无论如何A数据的样本还是占优的。

    2、knn算法如果纯粹凭借分类的多少做判断,还是可以继续优化的,比如近的数据的权重可以设大,最后根据所有的类型权重和进行比较,而不是单纯的凭借数量。

    3、knn算法的缺点是计算量大,这个从程序中也应该看得出来,里面每个测试数据都要计算到所有的训练集数据之间的欧式距离,时间复杂度就已经为O(n*n),如果真实数据的n非常大,这个算法的开销的确态度,所以KNN不适合大规模数据量的分类。

    KNN算法编码时遇到的困难:

    按理来说这么简单的KNN算法本应该是没有多少的难度,但是在多欧式距离的排序上被深深的坑了一段时间,本人起初用Collections.sort(list)的方式进行按距离排序,也把Sample类实现了Compareable接口,但是排序就是不变,最后才知道,distance的int类型要改为Integer引用类型,在compareTo重载方法中调用distance的.CompareTo()方法就成功了,这个小细节平时没注意,难道属性的比较最终一定要调用到引用类型的compareTo()方法?这个小问题竟然花费了我一段时间,最后仔细的比较了一下网上的例子最后才发现......

  • 相关阅读:
    204. Count Primes (Integer)
    203. Remove Linked List Elements (List)
    202. Happy Number (INT)
    201. Bitwise AND of Numbers Range (Bit)
    200. Number of Islands (Graph)
    199. Binary Tree Right Side View (Tree, Stack)
    198. House Robber(Array; DP)
    191. Number of 1 Bits (Int; Bit)
    190. Reverse Bits (Int; Bit)
    189. Rotate Array(Array)
  • 原文地址:https://www.cnblogs.com/bianqi/p/12184059.html
Copyright © 2011-2022 走看看