zoukankan      html  css  js  c++  java
  • KNN算法java实现代码注释




    源程序共定义了三个class文件,分别是:public class KNNNode;public class KNN;public class TestKNN。


    KNNNode: KNN结点类,用来存储最近邻的k个元组相关的信息 

    KNN:      KNN算法主体类 

    TestKNN: KNN算法测试类 


    1、 TestKNN

    Method: public void read(List<List<Double>> datas, String path)

    读取文件中的数据,存储为数组的形式(以嵌套链表的形式实现)List<List<Double>> datas



    2、 算法主体:KNN


    3、 定义了一个数据节点数据结构:KNNNode


    package KNN;
    import java.util.ArrayList;
    import java.util.Comparator;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    import java.util.PriorityQueue;
    /**
     * Main class of the KNN algorithm.
     * @author Rowen
     * @qq 443773264
     * @mail luowen3405@163.com
     * @blog blog.csdn.net/luowen3405
     * @date 2011.03.25
     */
    public class KNN {
         * 设置优先级队列的比较函数,距离越大,优先级越高
        private Comparator<KNNNode> comparator = new Comparator<KNNNode>() {
            public int compare(KNNNode o1, KNNNode o2) {
                if (o1.getDistance() >= o2.getDistance()) {
                    return 1;
                } else {
                    return 0;
         * 获取K个不同的随机数
         * @param k 随机数的个数
         * @param max 随机数最大的范围
         * @return 生成的随机数数组
        public List<Integer> getRandKNum(int k, int max) {
            List<Integer> rand = new ArrayList<Integer>(k);
            for (int i = 0; i < k; i++) {
                int temp = (int) (Math.random() * max);
                if (!rand.contains(temp)) {
                } else {
            return rand;
         * 计算测试元组与训练元组之前的距离
         * @param d1 测试元组
         * @param d2 训练元组
         * @return 距离值
        public double calDistance(List<Double> d1, List<Double> d2) {
            double distance = 0.00;
            for (int i = 0; i < d1.size(); i++) {
                distance += (d1.get(i) - d2.get(i)) * (d1.get(i) - d2.get(i));
            return distance;
         * 执行KNN算法,获取测试元组的类别
         * @param datas 训练数据集
         * @param testData 测试元组
         * @param k 设定的K值
         * @return 测试元组的类别
        public String knn(List<List<Double>> datas, List<Double> testData, int k) {
            PriorityQueue<KNNNode> pq = new PriorityQueue<KNNNode>(k, comparator);//按照自然顺序存储容量为k的优先级队列
            List<Integer> randNum = getRandKNum(k, datas.size()); // 建立一个列表,列表中保存的是训练数据集中实例的个数
            for (int i = 0; i < k; i++) {
                int index = randNum.get(i);
                List<Double> currData = datas.get(index);
                String c = currData.get(currData.size() - 1).toString();
                KNNNode node = new KNNNode(index, calDistance(testData, currData), c);
    //            System.out.println("距离"+node.getDistance()+"测试样例"+index+"k值"+k);
            for (int i = 0; i < datas.size(); i++) {
                List<Double> t = datas.get(i);
                double distance = calDistance(testData, t);
                KNNNode top = pq.peek();
                if (top.getDistance() > distance) {
                    pq.add(new KNNNode(i, distance, t.get(t.size() - 1).toString()));
            return getMostClass(pq);
         * 获取所得到的k个最近邻元组的多数类
         * @param pq 存储k个最近近邻元组的优先级队列
         * @return 多数类的名称
        private String getMostClass(PriorityQueue<KNNNode> pq) {
            Map<String, Integer> classCount = new HashMap<String, Integer>();
            for (int i = 0; i < pq.size(); i++) {
                KNNNode node = pq.remove();
                String c = node.getC();
                if (classCount.containsKey(c)) {
                    classCount.put(c, classCount.get(c) + 1);
                } else {
                    classCount.put(c, 1);
            int maxIndex = -1;
            int maxCount = 0;
            Object[] classes = classCount.keySet().toArray();
            for (int i = 0; i < classes.length; i++) {
                if (classCount.get(classes[i]) > maxCount) {
                    maxIndex = i;
                    maxCount = classCount.get(classes[i]);
            return classes[maxIndex].toString();
    package KNN;
    /**
     * KNN node class, used to store the information of one of the k nearest neighbour tuples.
     * @author Rowen
     * @qq 443773264
     * @mail luowen3405@163.com
     * @blog blog.csdn.net/luowen3405
     * @date 2011.03.25
     */
    public class KNNNode {
        // Mutable holder for one neighbour candidate: tuple index, distance and class label.

        private int index; // index of the tuple
        private double distance; // distance to the test tuple
        private String c; // class the tuple belongs to

        /**
         * Creates a node.
         *
         * @param index    index of the tuple
         * @param distance distance between the tuple and the test tuple
         * @param c        class label of the tuple
         */
        public KNNNode(int index, double distance, String c) {
            this.index = index;
            this.distance = distance;
            this.c = c;
        }

        /** @return index of the tuple */
        public int getIndex() {
            return index;
        }

        /** @param index new index of the tuple */
        public void setIndex(int index) {
            this.index = index;
        }

        /** @return distance to the test tuple */
        public double getDistance() {
            return distance;
        }

        /** @param distance new distance to the test tuple */
        public void setDistance(double distance) {
            this.distance = distance;
        }

        /** @return class label of the tuple */
        public String getC() {
            return c;
        }

        /** @param c new class label of the tuple */
        public void setC(String c) {
            this.c = c;
        }
    }
    package KNN;
    import java.io.BufferedReader;
    import java.io.File;
    import java.io.FileReader;
    import java.util.ArrayList;
    import java.util.List;
    /**
     * Test class for the KNN algorithm.
     * @author Rowen
     * @qq 443773264
     * @mail luowen3405@163.com
     * @blog blog.csdn.net/luowen3405
     * @date 2011.03.25
     */
    public class TestKNN {
         * 从数据文件中读取数据
         * @param datas 存储数据的集合对象
         * @param path 数据文件的路径
        public void read(List<List<Double>> datas, String path){
            try {
                BufferedReader br = new BufferedReader(new FileReader(new File(path)));
                String data = br.readLine();
                List<Double> l = null;
                while (data != null) {
                    String t[] = data.split("    ");
                    l = new ArrayList<Double>();
                    for (int i = 0; i < t.length; i++) {
    //                    System.out.println(l);
                    data = br.readLine();
            } catch (Exception e) {
         * 程序执行入口
         * @param args
        public static void main(String[] args) {
            TestKNN t = new TestKNN();
            String datafile = new File("").getAbsolutePath() + File.separator + "datafile";
            String testfile = new File("").getAbsolutePath() + File.separator + "testfile";
    //        System.out.println(datafile);
            try {
                List<List<Double>> datas = new ArrayList<List<Double>>();
                List<List<Double>> testDatas = new ArrayList<List<Double>>();
                t.read(datas, datafile);
                t.read(testDatas, testfile);
    //            System.out.println(datas);
                KNN knn = new KNN();
                for (int i = 0; i < testDatas.size(); i++) {
                    List<Double> test = testDatas.get(i);
                    System.out.print("测试元组: ");
                    for (int j = 0; j < test.size(); j++) {
                        System.out.print(test.get(j) + " ");
                    System.out.print("类别为: ");
                    System.out.println(Math.round(Float.parseFloat((knn.knn(datas, test, 3)))));
            } catch (Exception e) {



    0.1887    0.3276    -1
    0.8178    0.7703    1
    0.6761    0.4849    -1
    0.6022    0.6878    -1
    0.1759    0.8217    -1
    0.2607    0.3502    1
    0.2875    0.6713    -1
    0.916    0.7363    -1
    0.1615    0.2564    1
    0.2653    0.9452    1
    0.0911    0.4386    -1
    0.0012    0.3947    -1
    0.4253    0.8419    1
    0.0067    0.4424    -1
    0.8244    0.2089    1
    0.3868    0.3592    -1
    0.9174    0.216    -1
    0.6074    0.3968    -1
    0.068    0.5201    -1
    0.9686    0.9937    1
    0.0908    0.3658    1
    0.3411    0.7691    -1
    0.4609    0.4423    -1
    0.1078    0.4501    1
    0.3445    0.0445    -1
    0.9827    0.7093    1
    0.2428    0.3774    -1
    0.0358    0.1971    -1
    0.82    0.721    1
    0.6718    0.6714    -1
    0.6753    0.2428    -1
    0.7218    0.4299    -1
    0.3127    0.8329    1
    0.0225    0.4162    1
    0.5313    0.2187    1
    0.7847    0.4243    -1
    0.2518    0.6476    1
    0.4076    0.5439    1
    0.9063    0.4587    1
    0.4714    0.2703    -1
    0.7702    0.0196    -1
    0.2548    0.3477    -1
    0.0942    0.5407    1
    0.1917    0.8085    -1
    0.6834    0.7689    -1
    0.1056    0.1097    1
    0.9577    0.5303    -1
    0.9436    0.0938    -1
    0.6959    0.3181    1
    0.4235    0.4484    1
    0.6171    0.6358    1
    0.5309    0.5447    1
    0.8444    0.2621    -1
    0.5762    0.8335    -1
    0.281    0.772    1
    0.224    0.15    -1
    0.4243    0.704    -1
    0.7384    0.7551    -1
    0.4401    0.9329    1
    0.2665    0.7635    1
    0.5944    0.662    1
    0.3225    0.3309    -1
    0.4709    0.2648    1
    0.6444    0.9899    -1
    0.5271    0.9727    1
    0.7788    0.4046    1
    0.7302    0.2362    1
    0.5181    0.6963    -1
    0.5841    0.6073    1
    0.7184    0.5225    1
    0.6999    0.1192    1
    0.3439    0.1194    1
    0.6951    0.7413    -1
    0.611    0.0636    1
    0.4229    0.5822    1
    0.4735    0.8878    -1
    0.2891    0.3935    -1
    0.3196    0.6393    1
    0.1527    0.3912    -1
    0.6385    0.9398    1
    0.2904    0.679    1
    0.4574    0.192    1
    0.3251    0.1058    1
    0.6377    0.5254    -1
    0.5985    0.8699    1
    0.4257    0.862    -1
    0.2691    0.7904    -1
    0.8754    0.1389    1
    0.0336    0.6456    1
    0.6544    0.6473    1


    0.9516    0.0326
    0.9203    0.5612
    0.0527    0.8819
    0.7379    0.6692
    0.2691    0.1904
    0.4228    0.3689
    0.5479    0.4607
    0.9427    0.9816
    0.4177    0.1564
    0.9831    0.8555
    0.3015    0.6448
    0.7011    0.3763
    0.6663    0.1909
    0.5391    0.4283
    0.6981    0.4820
    0.6665    0.1206
    0.1781    0.5895
    0.1280    0.2262
    0.9991    0.3846
    0.1711    0.5830


    测试元组: 0.9516 0.0326 类别为: -1
    测试元组: 0.9203 0.5612 类别为: -1
    测试元组: 0.0527 0.8819 类别为: -1
    测试元组: 0.7379 0.6692 类别为: -1
    测试元组: 0.2691 0.1904 类别为: -1
    测试元组: 0.4228 0.3689 类别为: -1
    测试元组: 0.5479 0.4607 类别为: -1
    测试元组: 0.9427 0.9816 类别为: 1
    测试元组: 0.4177 0.1564 类别为: 1
    测试元组: 0.9831 0.8555 类别为: -1
    测试元组: 0.3015 0.6448 类别为: -1
    测试元组: 0.7011 0.3763 类别为: -1
    测试元组: 0.6663 0.1909 类别为: -1
    测试元组: 0.5391 0.4283 类别为: -1
    测试元组: 0.6981 0.482 类别为: -1
    测试元组: 0.6665 0.1206 类别为: 1
    测试元组: 0.1781 0.5895 类别为: 1
    测试元组: 0.128 0.2262 类别为: 1
    测试元组: 0.9991 0.3846 类别为: -1
    测试元组: 0.1711 0.583 类别为: 1
  • 相关阅读:
    Qt Release 构建时强制包含调试信息
    Spring Kafka(二)操作Topic以及Kafka Tool 2的使用
    Nodejs-JWT token认证:为什么要使用token、token组成(头部、载荷、签名)、jwt使用过程以及token对比session的好处(单点登录、减轻服务器压力、存储信息等)
    [Kotlin] Multi ways to write constuctor in Kotlin
    [CSS] Use CSS Transforms to Create Configurable 3D Cuboids
    [CSS] Use CSS Variables Almost like Boolean Values with Calc (maintainable css)
    [Kotlin] Typecheck with 'is' keyword, 'as' keyword for assert type
    [Kotlin] When to add () and when not to
  • 原文地址:https://www.cnblogs.com/7899-89/p/3620346.html
Copyright © 2011-2022 走看看