zoukankan      html  css  js  c++  java
  • KNN算法及java实现

    public class KNNNode {
    
        private int index;// 元祖标号
        private double distance;// 与测试元祖的距离
        private String c;// 所属类别
    
        public KNNNode(int index, double distance, String c) {
    
            super();
            this.index = index;
            this.distance = distance;
            this.c = c;
    
        }
    
        public int getIndex() {
            return index;
        }
    
        public void setIndex(int index) {
            this.index = index;
        }
    
        public double getDistance() {
            return distance;
        }
    
        public void setDistance(double distance) {
            this.distance = distance;
        }
        
        
        public String getC(){
            return c;
        }
        
        public void setC(){
            this.c=c;
        }
    }
    import java.util.ArrayList;
    import java.util.Comparator;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    import java.util.PriorityQueue;
    
    //KNN算法主体
    public class KNN {
    //设置优先级队列的函数,距离越大,优先级越高
        private Comparator<KNNNode> comparator =new Comparator<KNNNode>(){
            public int compare(KNNNode o1,KNNNode o2){
                if(o1.getDistance()>=o2.getDistance()){
                    return -1;
                }
                else{
                    return 1;
                }
            }
    };
        
        /**
         * 获取K个不同的随机数
         * @param k随机数的个数
         * @param max随机数最大的范围
         * @return 生成随机数数组
         */
        
        public List<Integer> getRandKNum(int k,int max){
            List<Integer> rand=new ArrayList<Integer>(k);
            for(int i=0;i<k;i++){
                int temp=(int)(Math.random()*max);
                if(!rand.contains(temp)){
                    rand.add(temp);
                }
                else{
                    i--;
                }
            }
            
            return rand;
        }
        
        /**
         * 计算测试元祖和训练元组之间的距离
         * @param d1测试元祖
         * @param d2训练元祖
         * @return 距离值
         */
        
        public double calDistance(List<Double> d1,List<Double>d2){
            double distance=0.0;
            for(int i=0;i<d1.size();i++){
                distance+=(d1.get(i)-d2.get(i))*(d1.get(i)-d2.get(i));
            }
            
            return distance;
            
        }
        
        /**
         * 执行Knn算法,获取测试元组的类别
         * @param datas 训练数据集
         * @param 测试元组
         * @param k 设定的k值
         * @return  测试元组的类别
         */
        
        
        public String knn(List<List<Double>> datas,List<Double> testData,int k){
            PriorityQueue<KNNNode> pq=new PriorityQueue<KNNNode>(k,comparator);
            List<Integer> randNum=getRandKNum(k,datas.size());
            for(int i=0;i<k;i++){
                int index=randNum.get(i);
                List<Double> currData=datas.get(index);
                String c=currData.get(currData.size()-1).toString();
                KNNNode node=new KNNNode(index, calDistance(testData,currData),c);
                pq.add(node);
            }
            for (int i = 0; i < datas.size(); i++) {
                List<Double> t=datas.get(i);
                double distance=calDistance(testData, t);
                KNNNode top=pq.peek();
                if(top.getDistance()>distance){
                    pq.remove();
                    pq.add(new KNNNode(i, distance,t.get(t.size()-1).toString()));
                }
            }
            return getMostClass(pq);
        }
        
        /**
         * 获得所得到的k个最近邻元组的多数类
         * @param pq存储k个最近邻元组的优先级队列
         * @return 多数类的名称
         */
        
        
        private String getMostClass(PriorityQueue<KNNNode> pq){
            Map<String, Integer> classCount=new HashMap<String, Integer>();
            int pqsize=pq.size();
            for(int i=0;i<pqsize;i++){
                KNNNode node=pq.remove();
                String c=node.getC();
                if(classCount.containsKey(c)){
                    
                    
                    classCount.put(c,classCount.get(c)+1);
                }
                else{
                    
                    classCount.put(c,1);
                }
            }
            
            int maxIndex=-1;
            int maxCount=0;
            Object[] classes=classCount.keySet().toArray();
            for(int i=0;i<classes.length;i++){
                if(classCount.get(classes[i])>maxCount){
                    maxIndex=i;
                    maxCount=classCount.get(classes[i]);
                }
            }
            return classes[maxIndex].toString();        
                    
            
        }
        
        
        
        
        
        
        
        
        
        
        
    }
    import java.io.BufferedReader;
    import java.io.File;
    import java.io.FileReader;
    import java.io.IOException;
    import java.util.ArrayList;
    import java.util.List;
    
    public class TestCNN {
        /**
         * 从文件中读取数据
         * 
         * @param datas存储数据的集合对象
         * @param path数据文件的路径
         */
    
        public void read(List<List<Double>> datas, String path) {
            try {
                BufferedReader bReader = new BufferedReader(new FileReader(new File(path)));
                String reader;
    
                reader = bReader.readLine();
                while (reader != null) {
                    String t[] = reader.split(" ");
    
                    ArrayList<Double> list = new ArrayList<Double>();
                    for (int i = 0; i < toString().length(); i++) {
                        list.add(Double.parseDouble(t[i]));
                    }
                    datas.add(list);
                    reader = bReader.readLine();
                }
            } catch (IOException e) {
                // TODO Auto-generated catch block
                e.printStackTrace();
            }
        }
    
        /**
         * 程序执行入口
         * 
         * @param args
         * 
         */
        public static void main(String[] args) {
    
        TestCNN testCNN=new TestCNN();
        String datafile=new File("").getAbsolutePath()+File.separator+"cqudata\datafile.txt";
        String testfile=new File("").getAbsolutePath()+File.separator+"cqudata\testfile.txt";
        List<List<Double>> datas=new ArrayList<List<Double>>();
        List<List<Double>> testDatas=new ArrayList<List<Double>>();
        testCNN.read(datas, datafile);
        testCNN.read(testDatas, testfile);
        
        KNN knn=new KNN();
        for(int i=0;i<testDatas.size();i++){
            
            List <Double> test=testDatas.get(i);
            System.out.println("测试元组为:");
            for (int j = 0; j < test.size(); j++) {
                System.out.println(test.get(j)+" ");
            }
            System.out.println("类别为: ");
            System.out.println(Math.round(Float.parseFloat(knn.knn(datas, test, 3))));
        }
    }
    }

     训练数据:
    1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 1
    1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 1
    1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 1
    1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 0
    1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 1
    1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 0
    实验数据:
    1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5
    1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8
    1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2
    1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5
    1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5
    1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5
    程序运行结果:
    测试元组: 1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 类别为: 1
    测试元组: 1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 类别为: 1
    测试元组: 1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 类别为: 1
    测试元组: 1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 类别为: 0

  • 相关阅读:
    python函数及模块
    Python分支结构及循环结构
    python基本的知识
    11.21学习总结
    进度日报28
    进度日报27
    进度日报26
    进度日报25
    进度日报24
    11.14学习总结
  • 原文地址:https://www.cnblogs.com/ilxx1988/p/3308847.html
Copyright © 2011-2022 走看看