zoukankan      html  css  js  c++  java
  • K近邻算法

    /**
     * Command-line driver for the {@code KNN} classifier.
     *
     * <p>Loads labelled samples from {@code data.txt} and query rows from
     * {@code test.txt}, then prints every query row followed by the category
     * predicted for it.
     */
    public class KnnTest 
    {
        /**
         * Reads a text file of whitespace-separated numbers and appends one
         * {@code List<Double>} per non-blank line to {@code list}.
         *
         * <p>I/O failures are reported with {@code printStackTrace()} and
         * whatever rows were parsed before the failure stay in {@code list}.
         * A token that is not a valid number makes {@link Double#parseDouble}
         * throw {@link NumberFormatException}, which propagates to the caller.
         *
         * @param path file to read
         * @param list destination that receives the parsed rows, in file order
         */
        public static void readFileToList(String path, List<List<Double>> list)
        {
            BufferedReader br = null;
            
            try {
                br = new BufferedReader(new FileReader(path));
                String line;
                // readLine() returning null is the end-of-stream signal;
                // ready() only tells whether a read would block.
                while ((line = br.readLine()) != null) {
                    String trimmed = line.trim();
                    if (trimmed.isEmpty()) {
                        continue;
                    }
                    // Split on runs of whitespace so tabs and repeated spaces
                    // do not produce empty tokens.
                    String[] tokens = trimmed.split("\\s+");
                    List<Double> box = new ArrayList<Double>();
                    
                    for (String num : tokens) {
                        box.add(Double.parseDouble(num));
                    }
                    list.add(box);
                }
            }
            catch (IOException ex) {
                ex.printStackTrace();
            }
            finally {
                // Always release the file handle, even when parsing fails.
                if (br != null) {
                    try {
                        br.close();
                    }
                    catch (IOException ex) {
                        ex.printStackTrace();
                    }
                }
            }
        }
        
        
        /**
         * Classifies every row of {@code test.txt} against the samples in
         * {@code data.txt} using 2 neighbours and prints the row followed by
         * the predicted category rounded to an integer.
         *
         * @param args ignored
         */
        public static void main(String[] args)
        {
            int length = 2; // number of neighbours (k) passed to the classifier
            String dataFile = "data.txt"; 
            String testFile = "test.txt";
            
            KNN knn = new KNN();
            
            try {
                List<List<Double>> dataList = new ArrayList<List<Double>>();
                List<List<Double>> testList = new ArrayList<List<Double>>();
                
                readFileToList(dataFile, dataList);
                readFileToList(testFile, testList);
                
                for (List<Double> test : testList) {
                    for (Double d : test) {
                        System.out.print(d + " ");
                    }
                    
                    String category = knn.knn(dataList, test, length);
                    // The category comes back as the text of a double (e.g. "1.0").
                    System.out.println(Math.round(Float.parseFloat(category)));
                }
            }
            catch (Exception ex) {
                ex.printStackTrace();
            }
        }
    }
    
    
    /**
     * Simple k-nearest-neighbour classifier over rows of doubles.
     *
     * <p>Each sample row holds its feature values followed by the category as
     * the last element. The distance between a query and a sample is the
     * squared Euclidean distance over the query's coordinates.
     */
    class KNN
    {
        /**
         * Orders nodes by descending distance so the head of the priority
         * queue is always the farthest neighbour currently kept; that is the
         * one to evict when a closer sample is found.
         */
        private static final Comparator<Node> comparator = new Comparator<Node>()
        {
            public int compare(Node n1, Node n2)
            {
                return Double.compare(n2.getDistans(), n1.getDistans());
            }
        };
        
        /**
         * Predicts the category of {@code test} by majority vote among its
         * {@code k} nearest samples.
         *
         * <p>When {@code k} exceeds the number of samples, all samples vote.
         * Ties between categories are resolved by the iteration order of the
         * internal vote map.
         *
         * @param example labelled samples; the last element of each row is the category
         * @param test    query row; only its own coordinates are compared
         * @param k       number of neighbours to consult, at least 1
         * @return the winning category, formatted by {@link Double#toString()}
         * @throws IllegalArgumentException if {@code k < 1} or {@code example}
         *         is {@code null} or empty
         */
        public String knn(List<List<Double>> example, List<Double> test, int k)
        {
            if (k < 1) {
                throw new IllegalArgumentException("k must be at least 1, got " + k);
            }
            if (example == null || example.isEmpty()) {
                throw new IllegalArgumentException("example must contain at least one sample");
            }
            
            int capacity = Math.min(k, example.size());
            PriorityQueue<Node> pq = new PriorityQueue<Node>(capacity, comparator);
            
            // Single deterministic pass: fill the queue with the first k
            // samples, then replace the farthest kept neighbour whenever a
            // strictly closer sample shows up. Every sample is visited once,
            // so none can be counted twice.
            for (int i = 0; i < example.size(); i++) {
                List<Double> list = example.get(i);
                double distans = calDistans(test, list);
                
                if (pq.size() < capacity) {
                    pq.add(new Node(i, distans, categoryOf(list)));
                }
                else if (pq.peek().getDistans() > distans) {
                    pq.remove();
                    pq.add(new Node(i, distans, categoryOf(list)));
                }
            }
            
            return getMostCategory(pq);
        }
        
        /**
         * Returns the category of a sample row, i.e. its last element as text.
         */
        private String categoryOf(List<Double> sample)
        {
            return sample.get(sample.size() - 1).toString();
        }
        
        /**
         * Drains {@code pq} and returns the category that occurs most often
         * among its nodes, or {@code null} when the queue is empty.
         */
        private String getMostCategory(PriorityQueue<Node> pq)
        {
            Map<String, Integer> rankMapping = new HashMap<String, Integer>();
            
            // Loop on emptiness: a counter compared against the shrinking
            // size() would stop after only half of the nodes.
            while (!pq.isEmpty()) {
                String category = pq.remove().getCategory();
                Integer seen = rankMapping.get(category);
                rankMapping.put(category, seen == null ? 1 : seen + 1);
            }
            
            String best = null;
            int count = 0;
            
            for (Map.Entry<String, Integer> entry : rankMapping.entrySet()) {
                if (entry.getValue() > count) {
                    best = entry.getKey();
                    count = entry.getValue();
                }
            }
            
            return best;
        }
        
        
        /**
         * Computes the squared Euclidean distance between two rows over the
         * first {@code list1.size()} coordinates.
         *
         * @param list1 row that determines how many coordinates are compared
         * @param list2 row with at least as many elements as {@code list1}
         * @return the sum of squared coordinate differences (no square root is taken)
         */
        public double calDistans(List<Double> list1, List<Double> list2)
        {
            double result = 0.00;
            
            for (int i = 0; i < list1.size(); i++) {
                double diff = list1.get(i) - list2.get(i);
                result += diff * diff;
            }
            
            return result;
        }
        
        
        /**
         * A candidate neighbour: the sample's position in the sample list, its
         * distance to the query, and its category.
         */
        static class Node
        {
            private int index;
            private double distans;
            private String category;
    
            public Node(int index, double distans, String category)
            {
                this.index = index;
                this.distans = distans;
                this.category = category;
            }
    
            public int getIndex() 
            {
                return index;
            }
    
            public void setIndex(int index) 
            {
                this.index = index;
            }
    
            public double getDistans() 
            {
                return distans;
            }
    
            public void setDistans(double distans) 
            {
                this.distans = distans;
            }
    
            public String getCategory() 
            {
                return category;
            }
    
            public void setCategory(String category) 
            {
                this.category = category;
            }
        }    
    }
  • 相关阅读:
    <转> 百度空间 最大子图形问题详解
    Hdu 1124 Factorial
    Uva 457 Linear Cellular Automata
    求01矩阵中的最大的正方形面积
    【HYSBZ】1036 树的统计Count
    【SPOJ】375 Query on a tree
    【POJ】3580 SuperMemo
    【CodeForces】191C Fools and Roads
    【FOJ】2082 过路费
    【HDU】3726 Graph and Queries
  • 原文地址:https://www.cnblogs.com/rilley/p/2690098.html
Copyright © 2011-2022 走看看