/**
 * One neighbour candidate of a KNN search: a training tuple's index, its distance to the
 * test tuple, and its class label.
 */
public class KNNNode {
    /** Index of the training tuple within the training data set. */
    private int index;
    /** Distance between the training tuple and the test tuple, as computed by the caller. */
    private double distance;
    /** Class label of the training tuple. */
    private String c;

    /**
     * Creates a neighbour candidate.
     *
     * @param index    index of the training tuple
     * @param distance distance to the test tuple
     * @param c        class label of the training tuple
     */
    public KNNNode(int index, double distance, String c) {
        super();
        this.index = index;
        this.distance = distance;
        this.c = c;
    }

    /** @return index of the training tuple */
    public int getIndex() {
        return index;
    }

    /** @param index new index of the training tuple */
    public void setIndex(int index) {
        this.index = index;
    }

    /** @return distance to the test tuple */
    public double getDistance() {
        return distance;
    }

    /** @param distance new distance to the test tuple */
    public void setDistance(double distance) {
        this.distance = distance;
    }

    /** @return class label of the training tuple */
    public String getC() {
        return c;
    }

    /**
     * Sets the class label.
     *
     * @param c new class label
     */
    public void setC(String c) {
        this.c = c;
    }

    /**
     * Former no-argument setter. It assigned the field to itself and therefore never changed
     * anything; it is kept only so existing callers still compile.
     *
     * @deprecated use {@link #setC(String)} instead
     */
    @Deprecated
    public void setC() {
        // Intentionally a no-op, matching the historical (self-assignment) behaviour.
    }
}
import java.util.ArrayList; import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.PriorityQueue; //KNN算法主体 public class KNN { //设置优先级队列的函数,距离越大,优先级越高 private Comparator<KNNNode> comparator =new Comparator<KNNNode>(){ public int compare(KNNNode o1,KNNNode o2){ if(o1.getDistance()>=o2.getDistance()){ return -1; } else{ return 1; } } }; /** * 获取K个不同的随机数 * @param k随机数的个数 * @param max随机数最大的范围 * @return 生成随机数数组 */ public List<Integer> getRandKNum(int k,int max){ List<Integer> rand=new ArrayList<Integer>(k); for(int i=0;i<k;i++){ int temp=(int)(Math.random()*max); if(!rand.contains(temp)){ rand.add(temp); } else{ i--; } } return rand; } /** * 计算测试元祖和训练元组之间的距离 * @param d1测试元祖 * @param d2训练元祖 * @return 距离值 */ public double calDistance(List<Double> d1,List<Double>d2){ double distance=0.0; for(int i=0;i<d1.size();i++){ distance+=(d1.get(i)-d2.get(i))*(d1.get(i)-d2.get(i)); } return distance; } /** * 执行Knn算法,获取测试元组的类别 * @param datas 训练数据集 * @param 测试元组 * @param k 设定的k值 * @return 测试元组的类别 */ public String knn(List<List<Double>> datas,List<Double> testData,int k){ PriorityQueue<KNNNode> pq=new PriorityQueue<KNNNode>(k,comparator); List<Integer> randNum=getRandKNum(k,datas.size()); for(int i=0;i<k;i++){ int index=randNum.get(i); List<Double> currData=datas.get(index); String c=currData.get(currData.size()-1).toString(); KNNNode node=new KNNNode(index, calDistance(testData,currData),c); pq.add(node); } for (int i = 0; i < datas.size(); i++) { List<Double> t=datas.get(i); double distance=calDistance(testData, t); KNNNode top=pq.peek(); if(top.getDistance()>distance){ pq.remove(); pq.add(new KNNNode(i, distance,t.get(t.size()-1).toString())); } } return getMostClass(pq); } /** * 获得所得到的k个最近邻元组的多数类 * @param pq存储k个最近邻元组的优先级队列 * @return 多数类的名称 */ private String getMostClass(PriorityQueue<KNNNode> pq){ Map<String, Integer> classCount=new HashMap<String, Integer>(); int pqsize=pq.size(); for(int i=0;i<pqsize;i++){ KNNNode 
node=pq.remove(); String c=node.getC(); if(classCount.containsKey(c)){ classCount.put(c,classCount.get(c)+1); } else{ classCount.put(c,1); } } int maxIndex=-1; int maxCount=0; Object[] classes=classCount.keySet().toArray(); for(int i=0;i<classes.length;i++){ if(classCount.get(classes[i])>maxCount){ maxIndex=i; maxCount=classCount.get(classes[i]); } } return classes[maxIndex].toString(); } }
import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.io.IOException; import java.util.ArrayList; import java.util.List; public class TestCNN { /** * 从文件中读取数据 * * @param datas存储数据的集合对象 * @param path数据文件的路径 */ public void read(List<List<Double>> datas, String path) { try { BufferedReader bReader = new BufferedReader(new FileReader(new File(path))); String reader; reader = bReader.readLine(); while (reader != null) { String t[] = reader.split(" "); ArrayList<Double> list = new ArrayList<Double>(); for (int i = 0; i < toString().length(); i++) { list.add(Double.parseDouble(t[i])); } datas.add(list); reader = bReader.readLine(); } } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } } /** * 程序执行入口 * * @param args * */ public static void main(String[] args) { TestCNN testCNN=new TestCNN(); String datafile=new File("").getAbsolutePath()+File.separator+"cqudata\datafile.txt"; String testfile=new File("").getAbsolutePath()+File.separator+"cqudata\testfile.txt"; List<List<Double>> datas=new ArrayList<List<Double>>(); List<List<Double>> testDatas=new ArrayList<List<Double>>(); testCNN.read(datas, datafile); testCNN.read(testDatas, testfile); KNN knn=new KNN(); for(int i=0;i<testDatas.size();i++){ List <Double> test=testDatas.get(i); System.out.println("测试元组为:"); for (int j = 0; j < test.size(); j++) { System.out.println(test.get(j)+" "); } System.out.println("类别为: "); System.out.println(Math.round(Float.parseFloat(knn.knn(datas, test, 3)))); } } }
训练数据(每行前 8 列为特征值,最后一列为类别标签):
1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 1
1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 1
1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 1
1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 0
1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 1
1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 0
测试数据(每行仅含 8 个特征值,不含类别标签):
1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5
1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8
1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2
1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5
1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5
1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5
程序运行结果(k = 3;以下仅列出前 4 个测试元组的结果):
测试元组为: 1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 类别为: 1
测试元组为: 1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 类别为: 1
测试元组为: 1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 类别为: 1
测试元组为: 1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 类别为: 0