zoukankan      html  css  js  c++  java
  • K-近邻算法kNN

      K-近邻算法(k-Nearest Neighbor,简称kNN)采用测量不同特征值之间的距离方法进行分类,是一种常用的监督学习方法,其工作机制很简单:给定测试样本,基于某种距离量度找出训练集中与其靠近的k个训练样本,然后基于这k个“邻居”的信息进行预测。kNN算法属于懒惰学习,此类学习技术在训练阶段仅仅是把样本保存起来,训练时间靠小为零,在收到测试样本后在进行处理,所以可知kNN算法的缺点是计算复杂度高、空间复杂度高。但其也有优点,精度高、对异常值不敏感、无数据输入设定。

      借张图来说:

    当k = 1时目标点有一个class2邻居,根据kNN算法的原理,目标点也为class2。

    当k = 5时目标点有两个class2邻居,有三个class1的邻居,根据其原理,目标点的类别为class2。

    算法流程

    总体来说,KNN分类算法包括以下4个步骤:

    ①准备数据,对数据进行预处理 。

    ②计算测试样本点(也就是待分类点)到其他每个样本点的距离。

    ③对每个距离进行排序,然后选择出距离最小的K个点 。

    ④对K个点所属的类别进行比较,根据少数服从多数的原则,将测试样本点归入在K个点中占比最高的那一类 。

    算法代码

    package com.top.knn;
    
    import com.top.constants.OrderEnum;
    import com.top.matrix.Matrix;
    import com.top.utils.MatrixUtil;
    
    import java.util.*;
    
    
    /**
     * @program: top-algorithm-set
     * @description: KNN k-临近算法进行分类
     * @author: Mr.Zhao
     * @create: 2020-10-13 22:03
     **/
    public class KNN {
        public static Matrix classify(Matrix input, Matrix dataSet, Matrix labels, int k) throws Exception {
            if (dataSet.getMatrixRowCount() != labels.getMatrixRowCount()) {
                throw new IllegalArgumentException("矩阵训练集与标签维度不一致");
            }
            if (input.getMatrixColCount() != dataSet.getMatrixColCount()) {
                throw new IllegalArgumentException("待分类矩阵列数与训练集列数不一致");
            }
            if (dataSet.getMatrixRowCount() < k) {
                throw new IllegalArgumentException("训练集样本数小于k");
            }
            // 归一化
            int trainCount = dataSet.getMatrixRowCount();
            int testCount = input.getMatrixRowCount();
            Matrix trainAndTest = dataSet.splice(2, input);
            Map<String, Object> normalize = MatrixUtil.normalize(trainAndTest, 0, 1);
            trainAndTest = (Matrix) normalize.get("res");
            dataSet = trainAndTest.subMatrix(0, trainCount, 0, trainAndTest.getMatrixColCount());
            input = trainAndTest.subMatrix(0, testCount, 0, trainAndTest.getMatrixColCount());
    
            // 获取标签信息
            List<Double> labelList = new ArrayList<>();
            for (int i = 0; i < labels.getMatrixRowCount(); i++) {
                if (!labelList.contains(labels.getValOfIdx(i, 0))) {
                    labelList.add(labels.getValOfIdx(i, 0));
                }
            }
    
            Matrix result = new Matrix(new double[input.getMatrixRowCount()][1]);
            for (int i = 0; i < input.getMatrixRowCount(); i++) {
                // 求向量间的欧式距离
                Matrix var1 = input.getRowOfIdx(i).extend(2, dataSet.getMatrixRowCount());
                Matrix var2 = dataSet.subtract(var1);
                Matrix var3 = var2.square();
                Matrix var4 = var3.sumRow();
                Matrix var5 = var4.pow(0.5);
                // 距离矩阵合并上labels矩阵
                Matrix var6 = var5.splice(1, labels);
                // 将计算出的距离矩阵按照距离升序排序
                var6.sort(0, OrderEnum.ASC);
                // 遍历最近的k个变量
                Map<Double, Integer> map = new HashMap<>();
                for (int j = 0; j < k; j++) {
                    // 遍历标签种类数
                    for (Double label : labelList) {
                        if (var6.getValOfIdx(j, 1) == label) {
                            map.put(label, map.getOrDefault(label, 0) + 1);
                        }
                    }
                }
                result.setValue(i, 0, getKeyOfMaxValue(map));
            }
            return result;
        }
    
        /**
         * 取map中值最大的key
         *
         * @param map
         * @return
         */
        private static Double getKeyOfMaxValue(Map<Double, Integer> map) {
            if (map == null)
                return null;
            Double keyOfMaxValue = 0.0;
            Integer maxValue = 0;
            for (Double key : map.keySet()) {
                if (map.get(key) > maxValue) {
                    keyOfMaxValue = key;
                    maxValue = map.get(key);
                }
            }
            return keyOfMaxValue;
        }
    
    }
    KNN

    注:其中的矩阵方法请参考https://github.com/ineedahouse/top-algorithm-set/blob/dev/src/main/java/com/top/matrix/Matrix.java

      升降序枚举类参考https://github.com/ineedahouse/top-algorithm-set/blob/dev/src/main/java/com/top/constants/OrderEnum.java

    该算法为本人github项目中的一部分,地址为https://github.com/ineedahouse/top-algorithm-set

    如果对你有帮助可以点个star~

    参考

    《机器学习》-周志华

    《机器学习实战》-Peter Harrington

     
  • 相关阅读:
    不用再去找rem了,你想要的rem都在这
    linux下ftp配置文件详解
    Linux chmod命令修改文件与文件夹权限命令代码
    如何在linux下开启FTP服务
    解决ftp客户端连接验证报错Server sent passive reply with unroutable address. Using server address instead
    预定义编译器宏
    类的成员变量修饰 const 和static
    【转】svn http://提示svn: Unrecognized URL scheme错误
    EVEREST Ultimate Edition 5.50 正式版 序列号
    [转]Linux下查看文件和文件夹大小
  • 原文地址:https://www.cnblogs.com/MrZhaoyx/p/13989760.html
Copyright © 2011-2022 走看看