K-近邻算法(k-Nearest Neighbor,简称kNN)采用测量不同特征值之间的距离方法进行分类,是一种常用的监督学习方法,其工作机制很简单:给定测试样本,基于某种距离量度找出训练集中与其靠近的k个训练样本,然后基于这k个“邻居”的信息进行预测。kNN算法属于懒惰学习,此类学习技术在训练阶段仅仅是把样本保存起来,训练时间靠小为零,在收到测试样本后在进行处理,所以可知kNN算法的缺点是计算复杂度高、空间复杂度高。但其也有优点,精度高、对异常值不敏感、无数据输入设定。
借张图来说:
当k = 1时目标点有一个class2邻居,根据kNN算法的原理,目标点也为class2。
当k = 5时目标点有两个class2邻居,有三个class1的邻居,根据其原理,目标点的类别为class2。
算法流程
总体来说,KNN分类算法包括以下4个步骤:
①准备数据,对数据进行预处理 。
②计算测试样本点(也就是待分类点)到其他每个样本点的距离。
③对每个距离进行排序,然后选择出距离最小的K个点 。
④对K个点所属的类别进行比较,根据少数服从多数的原则,将测试样本点归入在K个点中占比最高的那一类 。
算法代码
package com.top.knn; import com.top.constants.OrderEnum; import com.top.matrix.Matrix; import com.top.utils.MatrixUtil; import java.util.*; /** * @program: top-algorithm-set * @description: KNN k-临近算法进行分类 * @author: Mr.Zhao * @create: 2020-10-13 22:03 **/ public class KNN { public static Matrix classify(Matrix input, Matrix dataSet, Matrix labels, int k) throws Exception { if (dataSet.getMatrixRowCount() != labels.getMatrixRowCount()) { throw new IllegalArgumentException("矩阵训练集与标签维度不一致"); } if (input.getMatrixColCount() != dataSet.getMatrixColCount()) { throw new IllegalArgumentException("待分类矩阵列数与训练集列数不一致"); } if (dataSet.getMatrixRowCount() < k) { throw new IllegalArgumentException("训练集样本数小于k"); } // 归一化 int trainCount = dataSet.getMatrixRowCount(); int testCount = input.getMatrixRowCount(); Matrix trainAndTest = dataSet.splice(2, input); Map<String, Object> normalize = MatrixUtil.normalize(trainAndTest, 0, 1); trainAndTest = (Matrix) normalize.get("res"); dataSet = trainAndTest.subMatrix(0, trainCount, 0, trainAndTest.getMatrixColCount()); input = trainAndTest.subMatrix(0, testCount, 0, trainAndTest.getMatrixColCount()); // 获取标签信息 List<Double> labelList = new ArrayList<>(); for (int i = 0; i < labels.getMatrixRowCount(); i++) { if (!labelList.contains(labels.getValOfIdx(i, 0))) { labelList.add(labels.getValOfIdx(i, 0)); } } Matrix result = new Matrix(new double[input.getMatrixRowCount()][1]); for (int i = 0; i < input.getMatrixRowCount(); i++) { // 求向量间的欧式距离 Matrix var1 = input.getRowOfIdx(i).extend(2, dataSet.getMatrixRowCount()); Matrix var2 = dataSet.subtract(var1); Matrix var3 = var2.square(); Matrix var4 = var3.sumRow(); Matrix var5 = var4.pow(0.5); // 距离矩阵合并上labels矩阵 Matrix var6 = var5.splice(1, labels); // 将计算出的距离矩阵按照距离升序排序 var6.sort(0, OrderEnum.ASC); // 遍历最近的k个变量 Map<Double, Integer> map = new HashMap<>(); for (int j = 0; j < k; j++) { // 遍历标签种类数 for (Double label : labelList) { if (var6.getValOfIdx(j, 1) == label) { map.put(label, map.getOrDefault(label, 0) + 1); } } } result.setValue(i, 0, getKeyOfMaxValue(map)); } return result; } /** * 取map中值最大的key * * @param map * @return */ private static Double getKeyOfMaxValue(Map<Double, Integer> map) { if (map == null) return null; Double keyOfMaxValue = 0.0; Integer maxValue = 0; for (Double key : map.keySet()) { if (map.get(key) > maxValue) { keyOfMaxValue = key; maxValue = map.get(key); } } return keyOfMaxValue; } }
注:其中的矩阵方法请参考https://github.com/ineedahouse/top-algorithm-set/blob/dev/src/main/java/com/top/matrix/Matrix.java
该算法为本人github项目中的一部分,地址为https://github.com/ineedahouse/top-algorithm-set
如果对你有帮助可以点个star~
参考
《机器学习》-周志华
《机器学习实战》-Peter Harrington