zoukankan      html  css  js  c++  java
  • java KNN算法

    由于看网上的java有点多,自己写了一份,本人也是初学者,有错误请提出,大家一起学习。

     1 import java.io.BufferedReader;
     2 import java.io.File;
     3 import java.io.FileNotFoundException;
     4 import java.io.FileReader;
     5 import java.io.IOException;
     6 import java.util.*;
     7 
     8 
     9 public class Index {
    10     public static void main(String[] args){
    11         List<List<Double>> Filedatas = new ArrayList<List<Double>>();
    12         List<List<Double>> Testdatas = new ArrayList<List<Double>>();
    13         
    14         readFile(Filedatas,Testdatas);
    15         KNN knn = new KNN();
    16         
    17         for(int i=0;i<Filedatas.size();i++){
    18             String s = knn.comdistance(3,Filedatas,Testdatas.get(i));
    19             print(s,Testdatas.get(i));
    20         }
    21     }
    22     //第4步、打印出结果
    23     private static void print(String s,List<Double> testdata) {
    24         System.out.print("测试数据:");
    25         for(int i=0;i<testdata.size();i++){
    26             System.out.print(testdata.get(i) + " ");
    27         }
    28         int label = Math.round(Float.parseFloat(s));
    29         System.out.println("所属类别:" + label);
    30     }
    31 
    32     //第1.1步、读取文件
    33     private static void readFile(List<List<Double>> Filedatas, List<List<Double>> Testdatas) {
    34         try {
    35             BufferedReader bfd = new BufferedReader(new FileReader(new File("D://a.txt")));
    36             Filedatas = read(bfd,Filedatas);
    37             BufferedReader bft = new BufferedReader(new FileReader(new File("D://b.txt")));
    38             Testdatas = read(bft,Testdatas);
    39         } catch (FileNotFoundException e) {
    40             e.printStackTrace();
    41         }        
    42     }
    43 
    44     //第1.2步、读取文件
    45     private static List<List<Double>> read(BufferedReader bf, List<List<Double>> datas) {
    46         try {
    47             String str = bf.readLine();
    48             while(str != null){
    49                 List<Double> d = new ArrayList<Double>();
    50                 String[] string = str.split(" "); 
    51                 for (String s : string) {
    52                     d.add(Double.parseDouble(s));
    53                 }
    54                 datas.add(d);
    55                 str = bf.readLine();
    56             }
    57         } catch (IOException e) {
    58             e.printStackTrace();
    59         }
    60         return datas;
    61     }
    62 
    63     
    64 }
     1 import java.util.Comparator;
     2 import java.util.HashMap;
     3 import java.util.List;
     4 import java.util.Map;
     5 import java.util.PriorityQueue;
     6 
     7 public class KNN {
     8     
     9     public String comdistance(int k, List<List<Double>> filedatas,List<Double> testdata) {
    10         //第2.1步、对加入queue队列的项进行距离的排序
    11         PriorityQueue<KNNNode> pq = new PriorityQueue<KNNNode>(k,new Comparator<KNNNode>() {    //优先级队列,按照distance的大小进行排列
    12             @Override
    13             public int compare(KNNNode o1, KNNNode o2) {
    14                 if(o1.getDistance() >= o2.getDistance()){
    15                     return -1;
    16                 }
    17                 else{
    18                     return 1;
    19                 }
    20             }
    21         });
    22         //第2.2步、计算测试点与训练点的距离,并add进队列,挑出与测试点距离最近的K个点
    23         for(int i=0;i<k;i++){
    24             double distance = 0;
    25             for(int j=0;j<filedatas.get(i).size()-1;j++){
    26                 distance += (filedatas.get(i).get(j) - testdata.get(j)) * (filedatas.get(i).get(j) - testdata.get(j));
    27             }
    28             KNNNode node = new KNNNode(filedatas.get(i).get(filedatas.get(i).size()-1).toString(),distance);
    29             pq.add(node);
    30         }
    31         for(int i=k;i<filedatas.size();i++){
    32             double distance = 0;
    33             for(int j=0;j<filedatas.get(i).size()-1;j++){
    34                 distance += (filedatas.get(i).get(j) - testdata.get(j)) * (filedatas.get(i).get(j) - testdata.get(j));
    35             }
    36             KNNNode node = new KNNNode(filedatas.get(i).get(filedatas.get(i).size()-1).toString(),distance);
    37             if(pq.peek().getDistance() >= distance ){
    38                 pq.remove();
    39                 pq.add(node);
    40             }
    41         }
    42         String s = decide(pq);
    43         return s;
    44     }
    45     //第3步、把选择好的最近的K个点的类别进行比较,多的即是测试点的类别
    46     private String decide(PriorityQueue<KNNNode> pq) {
    47         Map<String,Integer> m = new HashMap<String,Integer>();
    48         for (KNNNode Node : pq) {
    49             if(m.containsKey(Node.getC())){
    50                 m.put(Node.getC(), m.get(Node.getC()) + 1);
    51             }
    52             else{
    53                 m.put(Node.getC(), 1);
    54             }
    55         }
    56         Object[] o = m.keySet().toArray();
    57 
    58         if(o.length == 1){
    59             return o[0].toString();
    60         }
    61         else{
    62             for(int i=0;i<o.length;i++){
    63                 for(int j=i;j<o.length;j++){
    64                     if(i != j){
    65                         if(m.get(o[i]) > m.get(o[j])){
    66                             return o[i].toString();
    67                         }
    68                         else{
    69                             return o[j].toString();
    70                         }
    71                     }
    72                 }
    73             }
    74         }
    75         return null;
    76     }
    77 }
     1 public class KNNNode {
     2     
     3     private String c;
     4     private double distance;
     5     
     6     public KNNNode(String c, double distance) {
     7         super();
     8         this.c = c;
     9         this.distance = distance;
    10     }
    11     
    12     public String getC() {
    13         return c;
    14     }
    15     public double getDistance() {
    16         return distance;
    17     }
    18     public void setC(String c) {
    19         this.c = c;
    20     }
    21     public void setDistance(double distance) {
    22         this.distance = distance;
    23     }
    24 }

    训练数据:

    1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 1
    1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 1
    1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 1
    1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 0
    1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 1
    1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 0

    测试数据:

    1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5
    1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8
    1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2
    1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5
    1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5
    1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5

  • 相关阅读:
    mysql之指定为definer的用户不存在
    Hibernate报错:org.hibernate.ObjectNotFoundException: No row with the given identifier exists 解决办法
    MongoDB mongo.exe启动及闪退解决 转载
    pycharm下运行unittest的问题
    mysql大小写敏感与校对规则
    windows7环境下使用pip安装MySQLdb
    HTML中title前面小图标和网站收藏现实的图标
    异步发送的请求---取消操作
    视频文件上传遇到的问题
    vue-devtools 必备开发工具
  • 原文地址:https://www.cnblogs.com/wn19910213/p/3325477.html
Copyright © 2011-2022 走看看