zoukankan      html  css  js  c++  java
  • k均值算法

    K均值算法是聚类分析中较常用的一种算法,基本思想如下:

    首先,随机地选择k个对象,每个对象代表一个簇的初始值或中心,对剩余的每个对象,根据其与各个簇均值的距离,将它指派到最相近的簇,然后计算每个簇的新均值。这个过程一直重复,直到准则函数收敛。

    关于距离,有几种不同的距离公式:

    求点群中心的算法

    一般来说,求点群中心点的算法你可以很简单地使用各个点的X/Y坐标的平均值。不过,我这里想告诉大家另三个求中心点的公式:

    1)Minkowski Distance 公式 —— λ 可以随意取值,可以是负数,也可以是正数,或是无穷大。

    $$d_{ij} = \left( \sum_{k=1}^{n} \left| x_{ik} - x_{jk} \right|^{\lambda} \right)^{1/\lambda}$$

    2)Euclidean Distance 公式 —— 也就是第一个公式 λ=2 的情况

    $$d_{ij} = \sqrt{\sum_{k=1}^{n} \left( x_{ik} - x_{jk} \right)^{2}}$$

    3)CityBlock Distance 公式 —— 也就是第一个公式 λ=1 的情况

    $$d_{ij} = \sum_{k=1}^{n} \left| x_{ik} - x_{jk} \right|$$

    算法实现:

    View Code
    package kmeans;
    
    /**
     * A two-dimensional sample point together with the index of the cluster
     * it is currently assigned to.
     *
     * <p>Accessors follow the file's convention of overloading one name per
     * property: {@code X()} reads, {@code X(double)} writes.
     */
    public class Data {
        /** Horizontal coordinate. */
        private double mX = 0;
        /** Vertical coordinate. */
        private double mY = 0;
        /** Index of the cluster this point belongs to; 0 until assigned. */
        private int mCluster = 0;

        /** Creates a point at the origin, assigned to cluster 0. */
        public Data()
        {
        }

        /**
         * Creates a point at the given coordinates, assigned to cluster 0.
         *
         * @param x horizontal coordinate
         * @param y vertical coordinate
         */
        public Data(double x, double y)
        {
            // Assign the fields directly instead of going through the public
            // setters: those are overridable, and calling them here would run
            // subclass code before the subclass has been initialised.
            this.mX = x;
            this.mY = y;
        }

        /**
         * Sets the horizontal coordinate.
         *
         * @param x new horizontal coordinate
         */
        public void X(double x)
        {
            this.mX = x;
        }

        /**
         * Sets the vertical coordinate.
         *
         * @param y new vertical coordinate
         */
        public void Y(double y)
        {
            this.mY = y;
        }

        /** @return the horizontal coordinate */
        public double X()
        {
            return this.mX;
        }

        /** @return the vertical coordinate */
        public double Y()
        {
            return this.mY;
        }

        /**
         * Assigns this point to a cluster.
         *
         * @param clusterNumber index of the cluster
         */
        public void cluster(int clusterNumber)
        {
            this.mCluster = clusterNumber;
        }

        /** @return the index of the cluster this point is assigned to */
        public int cluster()
        {
            return this.mCluster;
        }

    }
    View Code
    package kmeans;
    
    /**
     * Mutable centre point of a cluster in the two-dimensional plane.
     *
     * <p>Each coordinate is exposed through a pair of overloads sharing one
     * name: the no-argument form reads, the one-argument form writes.
     */
    public class Centroid {
        /** Horizontal position of the centre. */
        private double xPos;
        /** Vertical position of the centre. */
        private double yPos;

        /** Creates a centroid located at (0.0, 0.0). */
        public Centroid() {
        }

        /**
         * Creates a centroid at the given position.
         *
         * @param newX horizontal position
         * @param newY vertical position
         */
        public Centroid(double newX, double newY) {
            xPos = newX;
            yPos = newY;
        }

        /** @return the horizontal position */
        public double X() {
            return xPos;
        }

        /**
         * Moves the centroid horizontally.
         *
         * @param newX new horizontal position
         */
        public void X(double newX) {
            xPos = newX;
        }

        /** @return the vertical position */
        public double Y() {
            return yPos;
        }

        /**
         * Moves the centroid vertically.
         *
         * @param newY new vertical position
         */
        public void Y(double newY) {
            yPos = newY;
        }

    }
    package kmeans;
    
    import java.util.ArrayList;
    
    /**
     * Minimal two-dimensional k-means clustering over the fixed {@link #SAMPLES} set.
     *
     * <p>Usage: call {@link #init()} once to seed the centroids, then
     * {@link #kMeanCluster()}. The results are left in {@link #dataSet}
     * (each point carries its cluster index) and {@link #centroids}.
     */
    public class KMeans {
        /** Number of clusters to form. */
        public static final int NUM_CLUSTERS = 2;
        /** Number of sample points fed into the algorithm. */
        public static final int TOTAL_DATA = 7;
        /** The sample points, one {x, y} pair per row. */
        public static final double SAMPLES[][] = new double[][]{
            {1.0, 1.0},
            {1.5, 2.0},
            {3.0, 4.0},
            {5.0, 7.0},
            {3.5, 5.0},
            {4.5, 5.0},
            {3.5, 4.5}
        };
        /** Points added so far, each tagged with its current cluster index. */
        public ArrayList<Data> dataSet = new ArrayList<Data>();
        /** Current cluster centres; index i is the centre of cluster i. */
        public ArrayList<Centroid> centroids = new ArrayList<Centroid>();

        /**
         * Seeds the two centroids at the lowest and highest sample points and
         * prints their positions to standard output.
         */
        public void init()
        {
            System.out.println("centroids initialized at:");
            centroids.add(new Centroid(1.0, 1.0)); // lowest sample
            centroids.add(new Centroid(5.0, 7.0)); // highest sample
            System.out.println("  ("+centroids.get(0).X()+", " + centroids.get(0).Y() + ")");
            System.out.println("     (" + centroids.get(1).X() + ", " + centroids.get(1).Y() + ")");
            System.out.print("\n");
        }

        /**
         * Runs the clustering.
         *
         * <p>Phase 1 adds the samples one at a time, assigning each to its
         * nearest centroid and recomputing the centroids after every
         * insertion. Phase 2 then alternates between recomputing the centroids
         * and reassigning every point, stopping once a full pass moves no
         * point to a different cluster.
         *
         * <p>Requires {@link #init()} to have populated {@link #centroids}.
         */
        public void kMeanCluster()
        {
            // Phase 1: add in new data, one at a time, recalculating centroids with each new one.
            int sampleNumber = 0;
            while (dataSet.size() < TOTAL_DATA)
            {
                Data newData = new Data(SAMPLES[sampleNumber][0], SAMPLES[sampleNumber][1]);
                dataSet.add(newData);
                newData.cluster(nearestCluster(newData));
                recalculateCentroids();
                sampleNumber++;
            }

            // Phase 2: iterate until the assignment is stable.
            boolean isStillMoving = true;
            while (isStillMoving)
            {
                recalculateCentroids();

                isStillMoving = false;
                for (Data point : dataSet)
                {
                    int nearest = nearestCluster(point);
                    // Compare BEFORE overwriting the old assignment; otherwise
                    // a move can never be detected and the loop stops after a
                    // single pass.
                    if (point.cluster() != nearest)
                    {
                        point.cluster(nearest);
                        isStillMoving = true;
                    }
                }
            }
        }

        /**
         * Moves each centroid to the mean position of the points currently
         * assigned to its cluster.
         */
        private void recalculateCentroids()
        {
            for (int i = 0; i < NUM_CLUSTERS; i++)
            {
                // Sums must be double: accumulating into an int with += would
                // truncate every coordinate and force integer division below.
                double totalX = 0.0;
                double totalY = 0.0;
                int totalInCluster = 0;
                for (Data point : dataSet)
                {
                    if (point.cluster() == i)
                    {
                        totalX += point.X();
                        totalY += point.Y();
                        totalInCluster++;
                    }
                }
                // A cluster can be empty (e.g. before any sample has been
                // assigned to it); its centroid is then left where it is.
                if (totalInCluster > 0)
                {
                    Centroid centroid = centroids.get(i);
                    centroid.X(totalX / totalInCluster);
                    centroid.Y(totalY / totalInCluster);
                }
            }
        }

        /**
         * Finds the cluster whose centroid is closest to the given point.
         * On a tie the lowest cluster index wins.
         *
         * @param d the point to classify
         * @return index of the nearest centroid, in {@code [0, NUM_CLUSTERS)}
         */
        private int nearestCluster(Data d)
        {
            double minimum = Double.POSITIVE_INFINITY;
            int nearest = 0;
            for (int i = 0; i < NUM_CLUSTERS; i++)
            {
                double distance = dist(d, centroids.get(i));
                if (distance < minimum)
                {
                    minimum = distance;
                    nearest = i;
                }
            }
            return nearest;
        }

        /**
         * Calculates the Euclidean distance between a point and a centroid.
         *
         * @param d - Data object.
         * @param c - Centroid object.
         * @return - double value.
         */
        private static double dist(Data d, Centroid c)
        {
            double dx = c.X() - d.X();
            double dy = c.Y() - d.Y();
            return Math.sqrt(dy * dy + dx * dx);
        }

    }
    View Code
    package kmeans;
    
    /**
     * Command-line driver: runs {@link KMeans} on its built-in samples and
     * prints the members of each cluster followed by the final centroids.
     */
    public class test {

        /**
         * Entry point.
         *
         * @param args ignored
         */
        public static void main(String[] args) {
            KMeans k = new KMeans();
            k.init();
            k.kMeanCluster();

            // Print out clustering results, one cluster at a time.
            for (int i = 0; i < KMeans.NUM_CLUSTERS; i++)
            {
                System.out.println("Cluster " + i + " includes:");
                // Iterate over what is actually in the data set rather than
                // assuming it holds exactly TOTAL_DATA entries.
                for (Data point : k.dataSet)
                {
                    if (point.cluster() == i)
                    {
                        System.out.println(point.X() + ", " + point.Y());
                    }
                }
                System.out.println();
            }

            System.out.println("Centroids finalized at:");
            for (int i = 0; i < KMeans.NUM_CLUSTERS; i++)
            {
                // Closing parenthesis added so the output is balanced.
                System.out.println("     (" + k.centroids.get(i).X() + ", " + k.centroids.get(i).Y() + ")");
            }
            System.out.print("\n");
        }

    }

    ref:http://mnemstudio.org/clustering-k-means-example-1.htm (neat)

    http://blog.jobbole.com/23157/

  • 相关阅读:
    线程与进程
    Java集合框架体系JCF
    Java异常
    抽象,接口和Object类
    Java三大特性
    面向对象
    数组
    Java 控制结构与方法
    数据类型与变量
    Java基础之入门
  • 原文地址:https://www.cnblogs.com/youxin/p/3024596.html
Copyright © 2011-2022 走看看