KNN算法基本的思路是比较好理解的,今天根据它的特点写了一个实例,我会把所有的数据和代码都写在下面供大家参考,不足之处,请指正。谢谢!
update:工程代码全部在本页面中,测试数据已丢失,建议去UCI Dataset中找一个自行测试一下。
几点说明:
1.KNN中的K=5;
2.在计算权重时,采用的是减去函数{1,0.8,0.6,0.4,0.2},当然你也可以采用反函数或高斯函数;
3.5%作为测试集(decision.txt),95%作为训练集(training.txt);
4.在计算costfun之前,对所有的属性进行了归一化,由于这里不知道数据集每个属性代表的含义,所以就一视同仁,实际情况下,应该具体问题具体分析;
XBWKNN.java
package XBWKNN; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; import java.util.List; /** * KNN算法 * @author XBW * @date 2014年8月16日 */ public class XBWKNN{ public final static int KofKNN=5; public final static double weight[]={1,0.9,0.7,0.4,0.1}; //减法函数y=1-0.2*x /** * knn * @param data * @param ds * @return ans */ public static int knn(Data data,DataSet ds){ int ans = 0; List<Data> dis=calcDis(data,ds); ans=calcKDis(data,dis); return ans; } /** * 计算训练集中所有向量的距离,排序之后取前K个 * @param data * @param ds * @return */ @SuppressWarnings("null") public static List<Data>calcDis(Data data,DataSet ds){ List<Data> anslist =new ArrayList<Data>(); double dx1=data.x1; double dx2=data.x2; double dx3=data.x3; for(int i=0;i<ds.ds.size();i++){ double x1=ds.ds.get(i).x1; double x2=ds.ds.get(i).x2; double x3=ds.ds.get(i).x3; ds.ds.get(i).costfun=Math.sqrt((dx1-x1)*(dx1-x1)+(dx2-x2)*(dx2-x2)+(dx3-x3)*(dx3-x3)); anslist.add(ds.ds.get(i)); } Collections.sort(anslist,new Comparator<Data>(){ public int compare(Data o1, Data o2) { Double s=o1.costfun-o2.costfun; if(s<0) return -1; else return 1; } }); return anslist; } /** * 按一定的权重计算出前K个 * @param data * @param ds * @return */ public static int calcKDis(Data data,List<Data> anslist){ Double[] anstype={0.0,0.0,0.0,0.0}; for(int i=0;i<KofKNN;i++){ if(anslist.get(i).type==1){ anstype[1]+=weight[i]; } else if(anslist.get(i).type==2){ anstype[2]+=weight[i]; } if(anslist.get(i).type==3){ anstype[3]+=weight[i]; } } Double maxt=-1.0; int tag=1; for(int i=1;i<=3;i++){ if(maxt<anstype[i]){ tag=i; maxt=anstype[i]; } } return tag; } public static void main(String[] args) throws IOException{ DataSet ds=new DataSet(); DataTest dt=new DataTest(); int correct=0; for(int i=0;i<dt.dt.size();i++){ Data data=dt.dt.get(i); int result=knn(data,ds); if(result==data.type){ correct++; } } System.out.println("total test num :"+dt.dt.size()); System.out.println("correct test num :"+correct); System.out.println("ratio :"+correct/(double)dt.dt.size()); } }
Datatest.java
package XBWKNN; import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.io.IOException; import java.util.ArrayList; import java.util.List; /** * 测试数据 * @author XBW * @date 2014年8月16日 */ public class DataTest{ String defaultpath="D:\MachineLearning\十大算法\KNN\knncode\decision.txt"; List<Data> dt; @SuppressWarnings("null") public DataTest() throws IOException{ List<Data> dset = new ArrayList<Data>(); File ds=new File(defaultpath); @SuppressWarnings("resource") BufferedReader br = new BufferedReader(new FileReader(ds)); String tsing; double max1=-1; double max2=-1; double max3=-1; while((tsing=br.readLine())!=null){ String[] dlist=tsing.split(" "); Data data=new Data(); data.x1=Double.parseDouble(dlist[0]); data.x2=Double.parseDouble(dlist[1]); data.x3=Double.parseDouble(dlist[2]); data.type=Integer.parseInt(dlist[3]); dset.add(data); if(data.x1>max1){ max1=data.x1; } if(data.x2>max2){ max2=data.x2; } if(data.x3>max3){ max3=data.x3; } } dset=normalization(dset,max1,max2,max3); this.dt=dset; } public List<Data> normalization(List<Data> dset,double m1,double m2,double m3){ for(int i=0;i<dset.size();i++){ dset.get(i).x1/=m1; dset.get(i).x2/=m2; dset.get(i).x3/=m3; } return dset; } }
DataSet.java
package XBWKNN; import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.io.IOException; import java.util.ArrayList; import java.util.List; /** * 训练数据 * @author XBW * @date 2014年8月16日 */ public class DataSet{ String defaultpath="D:\MachineLearning\十大算法\KNN\knncode\training.txt"; List<Data> ds; @SuppressWarnings("null") public DataSet() throws IOException{ List<Data> dset =new ArrayList<Data>(); File ds=new File(defaultpath); @SuppressWarnings("resource") BufferedReader br = new BufferedReader(new FileReader(ds)); String tsing; double max1=-1; double max2=-1; double max3=-1; while((tsing=br.readLine())!=null){ String[] dlist=tsing.split(" "); Data data=new Data(); data.x1=Double.parseDouble(dlist[0]); data.x2=Double.parseDouble(dlist[1]); data.x3=Double.parseDouble(dlist[2]); data.type=Integer.parseInt(dlist[3]); dset.add(data); if(data.x1>max1){ max1=data.x1; } if(data.x2>max2){ max2=data.x2; } if(data.x3>max3){ max3=data.x3; } } dset=normalization(dset,max1,max2,max3); this.ds=dset; } public List<Data> normalization(List<Data> dset,double m1,double m2,double m3){ for(int i=0;i<dset.size();i++){ dset.get(i).x1/=m1; dset.get(i).x2/=m2; dset.get(i).x3/=m3; } return dset; } }
Data.java
package XBWKNN; /** * 一条数据 * @author XBW * @date 2014年8月16日 */ public class Data{ Double x1; Double x2; Double x3; Double costfun; int type; }
output: