zoukankan      html  css  js  c++  java
  • knn分类算法学习

        K最近邻(k-Nearest Neighbor,KNN)分类算法,是一个理论上比较成熟的方法,也是最简单的机器学习算法之一。该方法的思路是:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。KNN算法中,所选择的邻居都是已经正确分类的对象。该方法在定类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。 KNN方法虽然从原理上也依赖于极限定理,但在类别决策时,只与极少量的相邻样本有关。由于KNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,KNN方法较其他方法更为适合。

       算法流程:    

      1. 准备数据,对数据进行预处理
      2. 选用合适的数据结构存储训练数据和测试元组
      3. 设定参数,如k
      4.维护一个大小为k的的按距离由大到小的优先级队列,用于存储最近邻训练元组。随机从训练元组中选取k个元组作为初始的最近邻元组,分别计算测试元组到这k个元组的距离,将训练元组标号和距离存入优先级队列
      5. 遍历训练元组集,计算当前训练元组与测试元组的距离,将所得距离L 与优先级队列中的最大距离Lmax
      6. 进行比较。若L>=Lmax,则舍弃该元组,遍历下一个元组。若L < Lmax,删除优先级队列中最大距离的元组,将当前训练元组存入优先级队列。
      7. 遍历完毕,计算优先级队列中k 个元组的多数类,并将其作为测试元组的类别。
      8. 测试元组集测试完毕后计算误差率,继续设定不同的k值重新进行训练,最后取误差率最小的k 值。
    主要代码:
         使用优先级队列求距离最近的近邻。
    /**
     * 小顶堆求topN
     */
    public class MinHeapPriorityQueue<T extends Comparable<T>> {
        private PriorityQueue<T> queue;
        private int maxSize;
        
        /**
         * @param maxSize
         */
        public MinHeapPriorityQueue(int maxSize) {
            this(maxSize, new Comparator<T>() {
                @Override
                public int compare(T o1, T o2) {
                    return o1.compareTo(o2);
                }
            });
        }
        
        public MinHeapPriorityQueue(int maxSize, Comparator<T> comparator) {
            this.maxSize = maxSize;
            this.queue = new PriorityQueue<>(maxSize, comparator);
        }
        
        public synchronized void insert(T t) {
            if (queue.size() < maxSize) {
                queue.add(t);
            } else {
                T tmp = queue.peek();
                if (t.compareTo(tmp) > 0) {
                    queue.poll();
                    queue.add(t);
                }
            }
        }
        
        public synchronized List<T> sortList() {
            List<T> list = new LinkedList<>(queue);
            Collections.sort(list, new Comparator<T>() {
                @Override
                public int compare(T o1, T o2) {
                    return o2.compareTo(o1);
                }
            });
            return list;
        }
    
        public synchronized List<T> getList(){
            List<T> list = new LinkedList<>(queue);
            return list;
        }
    
        public static double format(double d, int n) {
            double p = Math.pow(10, n);
            return Math.round(d * p) / p;
        }
    
    
    
        
        public static void main(String[] args) {
            MinHeapPriorityQueue<Double> queue = new MinHeapPriorityQueue<>(3);
            Random r = new Random();
            StringBuffer buf = new StringBuffer();
            for (int i = 0; i < 20; i++) {
                double rd = format(r.nextDouble(), 3);
                queue.insert(rd);
                buf.append(rd);
                if (i != 19)
                    buf.append(", ");
            }
            System.out.println("buff: " + buf.toString());
            System.out.println("list: " + queue.sortList());
        }
    }

        knn算法实现:

    public class KNN {
    
        public String knn(List<List<Double>> datas, List<Double> testData, int k) {
            MinHeapPriorityQueue queue = new MinHeapPriorityQueue(k);
            for (int i = 0; i < datas.size(); i++) {
                List<Double> t = datas.get(i);
                double distance = calDistance(t, testData);
                queue.insert(new TrainTuple(i, distance, t.get(t.size() - 1).toString()));
            }
            return getMostClass(queue);
        }
    
        /**
         * 计算测试数据和训练数据元组的距离
         *
         * @param trainData
         * @param testData
         * @return
         */
        private double calDistance(List<Double> trainData, List<Double> testData) {
            double sum = 0d;
            double distance = 0d;
            for (int i = 0; i < trainData.size() - 1 ; i++) {
                sum += (trainData.get(i) - testData.get(i)) * (trainData.get(i) - testData.get(i));
            }
            distance = Math.sqrt(sum);
            return distance;
        }
    
        /**
         * 获取所得到的k个最近邻元组的多数类别
         *
         * @param queue
         * @return 多数类别名称
         */
        private String getMostClass(MinHeapPriorityQueue queue) {
            Map<String, Integer> classCountMap = new HashMap<>();
            List<TrainTuple> arrayList = queue.getList();
            for (int i = 0; i < arrayList.size(); i++) {
                TrainTuple tuple = arrayList.get(i);
                String classify = tuple.getClassify();
                if(classCountMap.containsKey(classify)){
                    classCountMap.put(tuple.getClassify(),classCountMap.get(classify) + 1);
                }else{
                    classCountMap.put(classify,1);
                }
            }
            int maxIndex = -1;
            int maxCount = 0;
            Object[] classes = classCountMap.keySet().toArray();
            for (int i = 0; i < classes.length; i++) {
                if (classCountMap.get(classes[i]) > maxCount) {
                    maxIndex = i;
                    maxCount = classCountMap.get(classes[i]);
                }
            }
            return classes[maxIndex].toString();
        }
    
    }

    具体的代码实现可以参考:https://github.com/yl897958450/knn

    转载请注明出处。

    http://www.cnblogs.com/ylcoder/
  • 相关阅读:
    BZOJ 2743: [HEOI2012]采花( 离线 + BIT )
    BZOJ 1031: [JSOI2007]字符加密Cipher( 后缀数组 )
    BZOJ 1717: [Usaco2006 Dec]Milk Patterns 产奶的模式( 二分答案 + 后缀数组 )
    HDU 2602 Find a way BFS搜索
    HDU 1495 非常可乐 BFS搜索
    UVA 11624 Fire! BFS搜索
    FZU2150 Fire Game BFS搜索
    POJ3414 Pots BFS搜素
    POJ3087 Shuffle'm Up 简单模拟
    POJ 3126 Prime Path BFS搜索
  • 原文地址:https://www.cnblogs.com/ylcoder/p/6285006.html
Copyright © 2011-2022 走看看