zoukankan      html  css  js  c++  java
  • 【Java】K-means算法Java实现以及图像切割

    1.K-means算法简述以及代码原型

    数据挖掘中一个重要算法是K-means。我这里就不做具体介绍。假设感兴趣的话能够移步陈皓的博客:   

     http://www.csdn.net/article/2012-07-03/2807073-k-means 讲得非常好

        总的来讲,k-means聚类须要下面几个步骤:

             ①.初始化数据

             ②.计算初始的中心点,能够随机选择

             ③.计算每一个点到每一个聚类中心的距离。而且划分到距离最短的聚类中心簇中

             ④.计算每一个聚类簇的平均值,这个均值作为新的聚类中心,反复步骤3

             ⑤.假设达到最大循环或者是聚类中心不再变化或者聚类中心变化幅度小于一定范围时,停止循环。

        恩。原理就是这样,超级简单。可是Java算法实现起来代码量并不小。这个代码也不算是全然自己写的啦。也有些借鉴。我把k-means实现封装在了一个类里面,这样就能够随时调用了呢。

          

    import java.util.ArrayList;
    import java.util.Random;
    
    public class kmeans {
    	private int k;//簇数
    	private int m;//迭代次数
    	private int dataSetLength;//数据集长度
    	private ArrayList<double[]> dataSet;//数据集合
    	private ArrayList<double[]> center;//中心链表
    	private ArrayList<ArrayList<double[]>> cluster;//簇
    	private ArrayList<Float> jc;//误差平方和,这个是用来计算中心聚点的移动哦
    	private Random random;
    	
    	//设置原始数据集合
    	public void setDataSet(ArrayList<double[]> dataSet){
    		this.dataSet=dataSet;
    	}
    	//获得簇分组
    	public  ArrayList<ArrayList<double[]>> getCluster(){
    		return this.cluster;
    	}
    	//构造函数,传入要分的簇的数量
    	public kmeans(int k){
    		if(k<=0)
    			k=1;
    		this.k=k;
    	}
    	//初始化
    	private void init(){
    		m=0;
    		random=new Random();
    		if(dataSet==null||dataSet.size()==0)
    			initDataSet();
    		dataSetLength=dataSet.size();
    		if(k>dataSetLength)
    			k=dataSetLength;
    		center=initCenters();
    		cluster=initCluster();
    		jc=new ArrayList<Float>();
    	}
    	//初始化数据集合
    	private void initDataSet(){
    		dataSet=new ArrayList<double[]>();
    		double[][] dataSetArray=new double[][]{{8,2},{3,4},{2,5},{4,2},
    				{7,3},{6,2},{4,7},{6,3},{5,3},{6,3},{6,9},
    				{1,6},{3,9},{4,1},{8,6}};
    		for(int i=0;i<dataSetArray.length;i++)
    			dataSet.add(dataSetArray[i]);
    	}
    	//初始化中心链表,分成几簇就有几个中心
    	private ArrayList<double[]> initCenters(){
    		ArrayList<double[]> center= new ArrayList<double[]>();
    		//生成一个随机数列。
    		int[] randoms=new int[k];
    		boolean flag;
    		int temp=random.nextInt(dataSetLength);
    		randoms[0]=temp;
    		for(int i=1;i<k;i++){
    			flag=true;
    			while(flag){
    				temp=random.nextInt(dataSetLength);
    				int j=0;
    				while(j<i){
    					if(temp==randoms[j])
    						break;
    					j++;
    				}
    				if(j==i)
    					flag=false;
    			}
    			randoms[i]=temp;
    		}
    		for(int i=0;i<k;i++)
    			center.add(dataSet.get(randoms[i]));
    		return center;
    	}
    	//初始化簇集合
    	private ArrayList<ArrayList<double[]>> initCluster(){
    		ArrayList<ArrayList<double[]>> cluster=
    				new ArrayList<ArrayList<double[]>>();
    		for(int i=0;i<k;i++)
    			cluster.add(new ArrayList<double[]>());
    		return cluster;
    	}
    	//计算距离
    	private double distance(double[] element,double[] center){
    		double distance=0.0f;
    		double x=element[0]-center[0];
    		double y=element[1]-center[1];
    		double z=element[2]-center[2];
    		double sum=x*x+y*y+z*z;
    		distance=(double)Math.sqrt(sum);
    		return distance;
    	}
    	//计算最短的距离
    	private int minDistance(double[] distance){
    		double minDistance=distance[0];
    		int minLocation=0;
    		for(int i=0;i<distance.length;i++){
    			if(distance[i]<minDistance){
    				minDistance=distance[i];
    				minLocation=i;
    			}else if(distance[i]==minDistance){
    				if(random.nextInt(10)<5){
    					minLocation=i;
    				}
    			}
    		}
    		return minLocation;
    	}
    	//每一个点分类
    	private void clusterSet(){
    		double[] distance=new double[k];
    		for(int i=0;i<dataSetLength;i++){
    			//计算到每一个中心店的距离
    			for(int j=0;j<k;j++)
    				distance[j]=distance(dataSet.get(i),center.get(j));
    			//计算最短的距离
    			int minLocation=minDistance(distance);
    			//把他加到聚类里
    			cluster.get(minLocation).add(dataSet.get(i));
    		}
    	}
    	//计算新的中心
    	private void setNewCenter(){
    		for(int i=0;i<k;i++){
    			int n=cluster.get(i).size();
    			if(n!=0){
    				double[] newcenter={0,0};
    				for(int j=0;j<n;j++){
    					newcenter[0]+=cluster.get(i).get(j)[0];
    					newcenter[1]+=cluster.get(i).get(j)[1];
    				}
    				newcenter[0]=newcenter[0]/n;
    				newcenter[1]=newcenter[1]/n;
    				center.set(i, newcenter);
    			}
    		}
    	}
    	//求2点的误差平方
    	private double errosquare(double[] element,double[] center){
    		double x=element[0]-center[0];
    		double y=element[1]-center[1];
    		double errosquare=x*x+y*y;
    		return errosquare;
    	}
    	//计算误差平方和准则函数
    	private void countRule(){
    		float jcf=0;
    		for(int i=0;i<cluster.size();i++){
    			for(int j=0;j<cluster.get(i).size();j++)
    				jcf+=errosquare(cluster.get(i).get(j),center.get(i));
    		jc.add(jcf);
    		}
    	}
    	//核心算法
    	private void Kmeans(){
    		//初始化各种变量,随机选定中心。初始化聚类
    		init();
    		//開始循环
    		while(true){
    			//把每一个点分到聚类中去
    			clusterSet();
    			//计算目标函数
    			countRule();
    			//检查误差变化。由于我规定的计算循环次数为50次,所以就不用计算这个啦。你要愿意用也能够,就是慢一点
    			/*
    			if(m!=0){
    				if(jc.get(m)-jc.get(m-1)==0)
    					break;
    			}*/
    			if(m>=50)
    				break;
    			//否则继续生成新的中心
    			setNewCenter();
    			m++;
    			cluster.clear();
    			cluster=initCluster();
    
    		}
    	}
        //仅仅暴露一个接口给外部类
    	public void execute(){
    		System.out.print("start kmeans
    ");
    		Kmeans();
    		System.out.print("kmeans end
    ");
    	}
            //用来在外面打印出来已经分好的聚类
    	public void printDataArray(ArrayList<double[]> data,String dataArrayName){
    		for(int i=0;i<data.size();i++){
    			System.out.print("print:"+dataArrayName+"["+i+"]={"+data.get(i)[0]+","+data.get(i)[1]+"}
    ");
    		}
    		System.out.print("==========================");
    	}
    }
      嗯。代码就是这样。凝视写的非常具体,也都能看得懂。

    以下我给一个測试样例。

    import java.util.ArrayList;
    
    public class Test {
    	public static void main(String[] args){
    		kmeans k=new kmeans(2);
    		ArrayList<double[]> dataSet=new ArrayList<double[]>();
    		dataSet.add(new double[]{2,2,2});
    		dataSet.add(new double[]{1,2,2});
    		dataSet.add(new double[]{2,1,2});
    		dataSet.add(new double[]{1,3,2});
    		dataSet.add(new double[]{3,1,2});
    		dataSet.add(new double[]{-2,-2,-2});
    		dataSet.add(new double[]{-1,-2,-2});
    		dataSet.add(new double[]{-2,-1,-2});
    		dataSet.add(new double[]{-3,-1,-2});
    		dataSet.add(new double[]{-1,-3,-2});
    
    
    		k.setDataSet(dataSet);
    		k.execute();
    		ArrayList<ArrayList<double[]>> cluster=k.getCluster();
    		for(int i=0;i<cluster.size();i++){
    			k.printDataArray(cluster.get(i), "cluster["+i+"]");
    		}
    	}
    }
       没啥难度,也就是输入写初始数据。然后运行k-means在进行分类。最后打印一下。

    这个原型代码非常粗糙。没有加入聚类个数以及循环次数的变量。这些须要自己动手啦。

    2.k-means应用图像切割

      我们能够把k-means聚类放在图像切割上,也就是说把一个颜色的像素分为一类,然后再涂一个颜色。

    像这样。

    左边就是聚类之前的,右边是聚类之后的,看起来还是满炫酷的。事实上聚类算法也是非常easy扩展到这里的。
    有以下四个提示(由于是作业,我决定先不放马,不然到时候作业雷同我的学分就咖喱gaygay了):
       ①.上面的原型代码是对二维的数据进行分类,那我们也知道。一个颜色有RGB三种原色构成,也就是说我们仅仅须要 在二维的基础上。加上一维数据就吼啦。非常easy有木有,改变下数组结构,在距离计算编程三维欧式距离就吼。
       ②.Java有自带的图像处理类,所以读取数据敲击方便。我给一点代码提示哦
    //读取指定文件夹的图片数据,而且写入数组,这个数据要继续处理
    	private int[][] getImageData(String path){
    		BufferedImage bi=null;
    		try{
    			bi=ImageIO.read(new File(path));
    		}catch (IOException e){
    			e.printStackTrace();
    		}
    		int width=bi.getWidth();
    		int height=bi.getHeight();
    		int [][] data=new int[width][height];
    		for(int i=0;i<width;i++)
    			for(int j=0;j<height;j++)
    				data[i][j]=bi.getRGB(i, j);
    		/*測试输出
    		for(int i=0;i<data.length;i++)
    			for(int j=0;j<data[0].length;j++)
    				System.out.println(data[i][j]);*/
    		return data;
    	}
    	//用来处理获取的像素数据,提取我们须要的写入dataItem数组
    	private dataItem[][] InitData(int [][] data){
    		dataItem[][] dataitems=new dataItem[data.length][data[0].length];
    		for(int i=0;i<data.length;i++){
    			for(int j=0;j<data[0].length;j++){
    				dataItem di=new dataItem();
    				Color c=new Color(data[i][j]);
    				di.r=(double)c.getRed();
    				di.g=(double)c.getGreen();
    				di.b=(double)c.getBlue();
    				di.group=1;
    				dataitems[i][j]=di;
    			}
    		}
    		return dataitems;
    	}
              //介货是用来输出图像的
    <pre name="code" class="java">           private void ImagedataOut(String path){
    		Color c0=new Color(255,0,0);
    		Color c1=new Color(0,255,0);
    		Color c2=new Color(0,0,255);
    		Color c3=new Color(128,128,128);
    		BufferedImage nbi=new BufferedImage(source.length,source[0].length,BufferedImage.TYPE_INT_RGB);
    		for(int i=0;i<source.length;i++){
    			for(int j=0;j<source[0].length;j++){
    				if(source[i][j].group==0)
    					nbi.setRGB(i, j, c0.getRGB());
    				else if(source[i][j].group==1)
    					nbi.setRGB(i, j, c1.getRGB());
    				else if(source[i][j].group==2)
    					nbi.setRGB(i, j, c2.getRGB());
    				else if (source[i][j].group==3)
    					nbi.setRGB(i, j, c3.getRGB());
    				//Color c=new Color((int)center[source[i][j].group].r,
    				//		(int)center[source[i][j].group].g,(int)center[source[i][j].group].b);
    				//nbi.setRGB(i, j, c.getRGB());
    			}
    		}
    		try{
    			ImageIO.write(nbi, "jpg", new File(path));
    		}catch(IOException e){
    			e.printStackTrace();
    			}
    	}

    
        非常舒爽。你问我dataItem是啥?等我交完作业我就告诉你。
        ③.有一点不同的是。注意数据格式。胖胖開始用的就是int类型,结果在计算新的聚类中心的时候溢出了呢。。

    。所幸鹏鹏改成了double。可是鹏鹏在计算距离的时候又写错了,最后还是机智的胖胖鹏解决掉了全部的bug。

        ④.注意读取图片的时候保护好数据的顺序,也就是用一个二维数组来存储,这样在写的时候就不用记录像素点的位置,输出的时候也非常方便。
       就是这些。。。

    等我作业交完就来一次完整的代码解说。

  • 相关阅读:
    最小二乘法求回归直线方程的推导过程
    最小二乘法求回归直线方程的推导过程
    Redis过期键的删除策略
    Redis过期键的删除策略
    最小二乘法求回归直线方程的推导过程
    最小二乘法求回归直线方程的推导过程
    不用第三方实现内网穿透
    不用第三方实现内网穿透
    X Redo丢失的4种情况及处理方法
    Problem D: 逆置链式链表(线性表)
  • 原文地址:https://www.cnblogs.com/claireyuancy/p/7388670.html
Copyright © 2011-2022 走看看