zoukankan      html  css  js  c++  java
  • K-means

    首先要来了解的一个概念就是聚类,简单地说就是把相似的东西分到一组,同 Classification (分类)不同,对于一个 classifier ,通常需要你告诉它“这个东西被分为某某类”这样一些例子,理想情况下,一个 classifier 会从它得到的训练集中进行“学习”,从而具备对未知数据进行分类的能力,这种提供训练数据的过程通常叫做 supervised learning (监督学习),而在聚类的时候,我们并不关心某一类是什么,我们需要实现的目标只是把相似的东西聚到一起,因此,一个聚类算法通常只需要知道如何计算相似 度就可以开始工作了,因此 clustering 通常并不需要使用训练数据进行学习,这在 Machine Learning 中被称作 unsupervised learning (无监督学习)。

      我们经常接触到的聚类分析,一般都是数值聚类,一种常见的做法是同时提取 N 种特征,将它们放在一起组成一个 N 维向量,从而得到一个从原始数据集合到 N 维向量空间的映射——你总是需要显式地或者隐式地完成这样一个过程,然后基于某种规则进行分类,在该规则下,同组分类具有最大的相似性。

      假设我们提取到原始数据的集合为(x1x2, …, xn),并且每个xi为d维的向量,K-means聚类的目的就是,在给定分类组数k(k ≤ n)值的条件下,将原始数据分成k类 
    S = {S1S2, …, Sk},在数值模型上,即对以下表达式求最小值:
    underset{mathbf{S}} {operatorname{arg\,min}} sum_{i=1}^{k} sum_{mathbf x_j in S_i} left| mathbf x_j - oldsymbolmu_i 
ight|^2
    这里μi 表示分类S的平均值。

      那么在计算机编程中,其又是如何实现的呢?其算法步骤一般如下:

    1、从D中随机取k个元素,作为k个簇的各自的中心。

    2、分别计算剩下的元素到k个簇中心的相异度,将这些元素分别划归到相异度最低的簇。

    3、根据聚类结果,重新计算k个簇各自的中心,计算方法是取簇中所有元素各自维度的算术平均数。

    4、将D中全部元素按照新的中心重新聚类。

    5、重复第4步,直到聚类结果不再变化。

    6、将结果输出。

      用数学表达式来说,

    设我们一共有 N 个数据点需要分为 K 个 cluster ,k-means 要做的就是最小化

    displaystyle J = sum_{n=1}^Nsum_{k=1}^K r_{nk} |x_n-mu_k|^2

    这个函数,其中 r_{nk} 在数据点 n 被归类到 cluster k 的时候为 1 ,否则为 0 。直接寻找 r_{nk} 和 mu_k 来最小化 J 并不容易,不过我们可以采取迭代的办法:先固定 mu_k ,选择最优的 r_{nk} ,很容易看出,只要将数据点归类到离他最近的那个中心就能保证 J 最小。下一步则固定 r_{nk},再求最优的 mu_k。将 J 对 mu_k 求导并令导数等于零,很容易得到 J 最小的时候 mu_k 应该满足:

    displaystyle mu_k=frac{sum_n r_{nk}x_n}{sum_n r_{nk}}

    亦即 mu_k 的值应当是所有 cluster k 中的数据点的平均值。由于每一次迭代都是取到 J 的最小值,因此 J 只会不断地减小(或者不变),而不会增加,这保证了 k-means 最终会到达一个极小值。虽然 k-means 并不能保证总是能得到全局最优解,但是对于这样的问题,像 k-means 这种复杂度的算法,这样的结果已经是很不错的了。

    首先 3 个中心点被随机初始化,所有的数据点都还没有进行聚类,默认全部都标记为红色,如下图所示:

    iter_00

    然后进入第一次迭代:按照初始的中心点位置为每个数据点着上颜色,重新计算 3 个中心点,结果如下图所示:

    iter_01

    可以看到,由于初始的中心点是随机选的,这样得出来的结果并不是很好,接下来是下一次迭代的结果:

    iter_02

    可以看到大致形状已经出来了。再经过两次迭代之后,基本上就收敛了,最终结果如下:

    iter_04

    不过正如前面所说的那样 k-means 也并不是万能的,虽然许多时候都能收敛到一个比较好的结果,但是也有运气不好的时候会收敛到一个让人不满意的局部最优解,例如选用下面这几个初始中心点:

    iter_00_bad

    最终会收敛到这样的结果:

    iter_03_bad

      整体来讲,K-means算法的聚类思想比较简单明了,并且聚类效果也还算可以,算是一种简单高效应用广泛的 clustering 方法,接下来,我将讨论其代码实现过程。

       KMeans算法本身思想比较简单,但是合理的确定K值和K个初始类簇中心点对于聚类效果的好坏有很大的影响。

     K-means的源码实现

      一般情况下,我们通过C++/Matlab/Python等语言进行实现K-means算法,结合近期我刚刚学的C++,先从C++实现谈起,C++里面我们一般采用的是OpenCV库中写好的K-means函数,即cvKmeans2,首先来看函数原型:
      从OpenCV manual看到的是:
    int cvKMeans2(const CvArr* samples, int nclusters,
            CvArr* labels, CvTermCriteria termcrit,
            int attempts=1, CvRNG* rng=0,int flags=0, 
            CvArr* centers=0,double* compactness=0);
    由于除去已经确定的参数,我们自己需要输入的为:
    void cvKMeans2( 
        const CvArr* samples, //输入样本的浮点矩阵,每个样本一行。 
        int cluster_count,  //所给定的聚类数目 
         * labels,    //输出整数向量:每个样本对应的类别标识 
         CvTermCriteria termcrit //指定聚类的最大迭代次数和/或精度(两次迭代引起的聚类中心的移动距离)
     ); 
    其使用例程为:

     1 #ifdef _CH_
    2 #pragma package <opencv>
    3 #endif
    4
    5 #define CV_NO_BACKWARD_COMPATIBILITY
    6
    7 #ifndef _EiC
    8 #include "cv.h"
    9 #include "highgui.h"
    10 #include <stdio.h>
    11 #endif
    12
    13 int main( int argc, char** argv )
    14 {
    15 #define MAX_CLUSTERS 5 //设置类别的颜色,个数(《=5)
    16 CvScalar color_tab[MAX_CLUSTERS];
    17 IplImage* img = cvCreateImage( cvSize( 500, 500 ), 8, 3 );
    18 CvRNG rng = cvRNG(-1);
    19 CvPoint ipt;
    20
    21 color_tab[0] = CV_RGB(255,0,0);
    22 color_tab[1] = CV_RGB(0,255,0);
    23 color_tab[2] = CV_RGB(100,100,255);
    24 color_tab[3] = CV_RGB(255,0,255);
    25 color_tab[4] = CV_RGB(255,255,0);
    26
    27 cvNamedWindow( "clusters", 1 );
    28
    29 for(;;)
    30 {
    31 char key;
    32 int k, cluster_count = cvRandInt(&rng)%MAX_CLUSTERS + 1;
    33 int i, sample_count = cvRandInt(&rng)%1000 + 1;
    34 CvMat* points = cvCreateMat( sample_count, 1, CV_32FC2 );
    35 CvMat* clusters = cvCreateMat( sample_count, 1, CV_32SC1 );
    36 cluster_count = MIN(cluster_count, sample_count);
    37
    38 /** generate random sample from multigaussian distribution */
    39 for( k = 0; k < cluster_count; k++ )
    40 {
    41 CvPoint center;
    42 CvMat point_chunk;
    43 center.x = cvRandInt(&rng)%img->width;
    44 center.y = cvRandInt(&rng)%img->height;
    45 cvGetRows( points, &point_chunk, k*sample_count/cluster_count,
    46 k == cluster_count - 1 ? sample_count :
    47 (k+1)*sample_count/cluster_count, 1 );
    48
    49 cvRandArr( &rng, &point_chunk, CV_RAND_NORMAL,
    50 cvScalar(center.x,center.y,0,0),
    51 cvScalar(img->width*0.1,img->height*0.1,0,0));
    52 }
    53
    54 /** shuffle samples */
    55 for( i = 0; i < sample_count/2; i++ )
    56 {
    57 CvPoint2D32f* pt1 = (CvPoint2D32f*)points->data.fl + cvRandInt(&rng)%sample_count;
    58 CvPoint2D32f* pt2 = (CvPoint2D32f*)points->data.fl + cvRandInt(&rng)%sample_count;
    59 CvPoint2D32f temp;
    60 CV_SWAP( *pt1, *pt2, temp );
    61 }
    62
    63 printf( "iterations=%d ", cvKMeans2( points, cluster_count, clusters,
    64 cvTermCriteria( CV_TERMCRIT_EPS+CV_TERMCRIT_ITER, 10, 1.0 ),
    65 5, 0, 0, 0, 0 ));
    66
    67 cvZero( img );
    68
    69 for( i = 0; i < sample_count; i++ )
    70 {
    71 int cluster_idx = clusters->data.i[i];
    72 ipt.x = (int)points->data.fl[i*2];
    73 ipt.y = (int)points->data.fl[i*2+1];
    74 cvCircle( img, ipt, 2, color_tab[cluster_idx], CV_FILLED, CV_AA, 0 );
    75 }
    76
    77 cvReleaseMat( &points );
    78 cvReleaseMat( &clusters );
    79
    80 cvShowImage( "clusters", img );
    81
    82 key = (char) cvWaitKey(0);
    83 if( key == 27 || key == 'q' || key == 'Q' ) // 'ESC'
    84 break;
    85 }
    86
    87 cvDestroyWindow( "clusters" );
    88 return 0;
    89 }
    90
    91 #ifdef _EiC
    92 main(1,"kmeans.c");
    93 #endif

      至于cvKmeans2函数的具体实现细节,可参见OpenCV源码

      下面是Java的实现代码:

    此代码已编译通过,数据用的是鸢尾花数据格式,其中具体数据用1,2,3代替了鸢尾花真实数据,共分为三类

    package util;
    
    import java.io.BufferedReader;
    import java.io.BufferedWriter;
    import java.io.FileReader;
    import java.io.FileWriter;
    import java.io.IOException;
    import java.util.ArrayList;
     
    //K-means算法实现
    /*
     * 聚类数目的选取以及初始节点的选取对k-means的聚类效果影响特别大
     */
     
    public class KmeansTest {
        //聚类的数目
        final static int ClassCount = 3;
        //样本数目(测试集)
        final static int InstanceNumber = 150; 
        //样本属性数目(测试)
        final static int FieldCount = 5;
       
        //设置异常点阈值参数(每一类初始的最小数目为InstanceNumber/ClassCount^t)
        final static double t = 2.0;
        //存放数据的矩阵
        private float[][] data;
       
        //每个类的均值中心(行数就是类别数)
        private float[][] classData;
       
        //噪声集合索引
        private ArrayList<Integer> noises;
       
        //存放每次变换结果的矩阵(最外层数量为类别数)
        private ArrayList<ArrayList<Integer>> result;
       
        //构造函数,初始化
        public KmeansTest()
        {
    	   //最后一位用来储存结果
    	   data = new float[InstanceNumber][FieldCount+1];
    	   classData = new float[ClassCount][FieldCount];
    	   result = new ArrayList<ArrayList<Integer>>(ClassCount);
    	   noises = new ArrayList<Integer>();
      
        }
     
       /**
        * 主函数入口
        * 测试集的文件名称为“测试集.data”,其中有1000*57大小的数据
        * 每一行为一个样本,有57个属性
        * 主要分为两个步骤
        * 1.读取数据
        * 2.进行聚类
        * 最后统计运行时间和消耗的内存
        * @param args
        */
       public static void main(String[] args) {
          // TODO Auto-generated method stub
           long startTime = System.currentTimeMillis();
           KmeansTest cluster = new KmeansTest();
           //读取数据
           cluster.readData("iris.data");
           //聚类过程
           cluster.cluster();
           //输出结果
           cluster.printResult("clusterResult.data");
           long endTime = System.currentTimeMillis();
           System.out.println("Total Time:"+ (endTime - startTime)/1000+"s");//系统运行耗时
           System.out.println("Memory Consuming:"+(float)(Runtime.getRuntime().totalMemory() -
              Runtime.getRuntime().freeMemory())/1000000 + "MB");//系统存储消耗
       }
       /*
       * 读取测试集的数据
       *
       * @param trainingFileName 测试集文件名
       */
       public void readData(String trainingFileName)
       {
           try
           {
    	      FileReader fr = new FileReader(trainingFileName);
    	      BufferedReader br = new BufferedReader(fr);
    	      //存放数据的临时变量
    	      String lineData = null;
    	      String[] splitData = null;
    	      int line = 0;
    	      //按行读取
    	      while(br.ready())//是否准备好被读
    	      {
    	          //得到原始的字符串
    	          lineData = br.readLine();
    	          splitData = lineData.split(",");
    	          //转化为数据
    	//        System.out.println("length:"+splitData.length);
    	          if(splitData.length>1)
    	          {
    	             for(int i = 0;i < splitData.length;i++)
    	             {
    	//              System.out.println(splitData[i]);
    	//              System.out.println(splitData[i].getClass());
    	                if(splitData[i].startsWith("Iris-setosa"))
    	                {
    	                   data[line][i] = (float) 1.0;//单纯的将原来的数据换成一样的数据,相当于有指导了
    	                }
    	                else if(splitData[i].startsWith("Iris-versicolor"))
    	                {
    	                   data[line][i] = (float) 2.0;
    	                }
    	                else if(splitData[i].startsWith("Iris-virginica"))
    	                {
    	                   data[line][i] = (float) 3.0;
    	                }
    	                else
    	                {   //将数据截取之后放进数组
    	                   data[line][i] = Float.parseFloat(splitData[i]);
    	                }
    	             }
    	             line++;
    	          }
    	      }
    	      System.out.println("line: "+line);
           }catch(IOException e)
           {
          e.printStackTrace();
           }
       }
       
       /*
       * 聚类过程,主要分为两步
       * 1.循环找初始点
       * 2.不断调整直到分类不再发生变化
       */
       public void cluster()
       {
           //数据归一化
           normalize();
           //标记是否需要重新找初始点
           boolean needUpdataInitials = true;
          
           //找初始点的迭代次数
           int times = 1;
           //找初始点
           while(needUpdataInitials)
           {
    	      needUpdataInitials = false;
    	      result.clear();
    	      System.out.println("Find Initials Iteration "+(times++)+"time(s)");
    	     
    	      //一次找初始点的尝试和根据初始点的分类
    	      findInitials();
    	      firstClassify();
    	     
    	      //如果某个分类的数目小于特定的阈值,则认为这个分类中的所有样本都是噪声点
    	      //需要重新找初始点
    	      for(int i = 0;i < result.size();i++)
    	      {
    	          if(result.get(i).size() < InstanceNumber/Math.pow(ClassCount,t))
    	          {
    		         needUpdataInitials = true;
    		         noises.addAll(result.get(i));
    	          }
    	      }
           }
          
           //找到合适的初始点后
           //不断的调整均值中心和分类,直到不再发生任何变化
           Adjust();
       }
      
       /*
        * 对数据进行归一化
        * 1.找每一个属性的最大值
        * 2.对某个样本的每个属性除以其最大值
        */
       public void normalize()
       {
           //找最大值(每行数据的最大值)
           float[] max = new float[FieldCount];
           for(int i = 0;i < InstanceNumber;i++)
           {
    	      for(int j = 0;j < FieldCount;j++)
    	      {
    	          if(data[i][j] > max[j])
    	          max[j] = data[i][j];
    	      }
           }
          
           //归一化
           for(int i = 0;i < InstanceNumber;i++)
           {
    	      for(int j = 0;j < FieldCount;j++)
    	      {
    	          data[i][j] = data[i][j]/max[j];
    	      }
           }
       }
      
       //关于初始向量的一次找寻尝试
       public void findInitials()
       {
           //a,b为标志距离最远的两个向量的索引
           int i,j,a,b;
           i = j = a = b = 0;
          
           //最远距离
           float maxDis = 0;
          
           //已经找到的初始点个数--------------初始点的个数和分类数的个数不一样么?
           int alreadyCls = 2;
          
           //存放已经标记为初始点的向量索引
           ArrayList<Integer> initials = new ArrayList<Integer>();
          
           //从两个开始
           for(;i < InstanceNumber;i++)
           {
    	      //噪声点
    	      if(noises.contains(i))
    	          continue;
    	      //long startTime = System.currentTimeMillis();
    	      j = i + 1;
    	      for(;j < InstanceNumber;j++)
    	      {
    	          //噪声点
    	          if(noises.contains(j))
    	        	  continue;
    	          //找出最大的距离并记录下来
    	          float newDis = calDis(data[i],data[j]);
    	          if(maxDis < newDis)
    	          {
    		         a = i;
    		         b = j;
    		         maxDis = newDis;
    	          }
    	      }
    	      //long endTime = System.currentTimeMillis();
    	      //System.out.println(i + "Vector Caculation Time:"+(endTime-startTime)+"ms");
           }
          
           //将前两个初始点记录下来
           initials.add(a);
           initials.add(b);
           classData[0] = data[a];
           classData[1] = data[b];
          
           //在结果中新建存放某样本索引的对象,并把初始点添加进去
           ArrayList<Integer> resultOne = new ArrayList<Integer>();
           ArrayList<Integer> resultTwo = new ArrayList<Integer>();
           resultOne.add(a);
           resultTwo.add(b);
           result.add(resultOne);
           result.add(resultTwo);
          
           //找到剩余的几个初始点
           while(alreadyCls < ClassCount)
           {
    	      i = j = 0;
    	      float maxMin = 0;
    	      int newClass = -1;
    	     
    	      //找最小值中的最大值
    	      for(;i < InstanceNumber;i++)
    	      {
    	          float min = 0;
    	          float newMin = 0;
    	          //找和已有类的最小值
    	          if(initials.contains(i))
    	        	  continue;
    	          //噪声点去除
    	          if(noises.contains(i))
    	        	  continue;
    	          for(j = 0;j < alreadyCls;j++)
    	          {
    		         newMin = calDis(data[i],classData[j]);
    		         if(min == 0 || newMin < min)
    		             min = newMin;
    	          }
    	         
    	          //新最小距离较大
    	          if(min > maxMin)
    	          {
    		         maxMin = min;
    		         newClass = i;
    	          }
    	      }
    	      //添加到均值集合和结果集合中
    	      //System.out.println("NewClass"+newClass);
    	      initials.add(newClass);
    	      classData[alreadyCls++] = data[newClass];
    	      ArrayList<Integer> rslt = new ArrayList<Integer>();
    	      rslt.add(newClass);
    	      result.add(rslt);
           }
       }
      
       //第一次分类
       public void firstClassify()
       {
           //根据初始向量分类
           for(int i = 0;i < InstanceNumber;i++)
           {
    	      float min = 0f;
    	      int clsId = -1;
    	      for(int j = 0;j < classData.length;j++)
    	      {
    	          //欧式距离
    	          float newMin = calDis(classData[j],data[i]);
    	          if(clsId == -1 || newMin <min)
    	          {
    		         clsId = j;
    		         min = newMin;
    	          }
    	         
    	      }
    	      //本身不再添加
    	      if(!result.get(clsId).contains(i))
    	          result.get(clsId).add(i);
    	    }
       }
       //迭代分类,直到各个类的数据不再变化
       public void Adjust()
       {
           //记录是否发生变化
           boolean change = true;
          
           //循环的次数
           int times = 1;
           while(change)
           {
    	      //复位
    	      change = false;
    	      System.out.println("Adjust Iteration "+(times++)+"time(s)");
    	                   
    	      //重新计算每个类的均值 
    	      for(int i = 0;i < ClassCount; i++){ 
    		      //原有的数据 
    		      ArrayList<Integer> cls = result.get(i); 
    		       
    		      //新的均值 
    		      float[] newMean = new float[FieldCount ]; 
    		       
    		      //计算均值 
    		      for(Integer index:cls){
    		    	  for(int j = 0;j < FieldCount ;j++) 
    		              newMean[j] += data[index][j]; 
    		      } 
    		      for(int j = 0;j < FieldCount ;j++) 
    		         newMean[j] /= cls.size(); 
    		      if(!compareMean(newMean, classData[i])){  
    		    	  classData[i] = newMean; 
    		          change = true; 
    		      } 
    	      } 
    	      //清空之前的数据 
    	      for(ArrayList<Integer> cls:result) 
    	       cls.clear(); 
    	        
    	      //重新分配 
    	      for(int i = 0;i < InstanceNumber;i++) 
    	      { 
    		       float min = 0f; 
    		       int clsId = -1; 
    		       for(int j = 0;j < classData.length;j++){ 
    		    	   float newMin = calDis(classData[j], data[i]); 
    		    	   if(clsId == -1 || newMin < min){ 
    		    		   clsId = j; 
    		    		   min = newMin; 
    		           } 
    		       } 
    		       data[i][FieldCount] = clsId; 
    		       result.get(clsId).add(i); 
    	       } 
    	                 
    	         //测试聚类效果(训练集) 
    	      //          for(int i = 0;i < ClassCount;i++){ 
    	      //              int positives = 0; 
    	      //              int negatives = 0; 
    	      //              ArrayList<Integer> cls = result.get(i); 
    	      //              for(Integer instance:cls) 
    	      //                  if (data[instance][FieldCount - 1] == 1f) 
    	      //                      positives ++; 
    	      //                  else 
    	      //                      negatives ++; 
    	      //              System.out.println(" " + i + " Positive: " + positives + " Negatives: " + negatives); 
    	      //          } 
    	      //          System.out.println(); 
           }             
       } 
              
       /**
       * 计算a样本和b样本的欧式距离作为不相似度
       * 
       * @param a     样本a
       * @param b     样本b
       * @return      欧式距离长度
       */ 
       private float calDis(float[] aVector,float[] bVector)  {
          double dis = 0;
          int i = 0;
          /*最后一个数据在训练集中为结果,所以不考虑  */
          for(;i < aVector.length;i++)
              dis += Math.pow(bVector[i] - aVector[i],2); 
          dis = Math.pow(dis, 0.5); 
          return (float)dis; 
       }
             
      /**
      * 判断两个均值向量是否相等
      * 
      * @param a 向量a
      * @param b 向量b
      * @return
      */ 
       private boolean compareMean(float[] a,float[] b) 
       {
           if(a.length != b.length)   
               return false; 
           for(int i =0;i < a.length;i++){ 
              if(a[i] > 0 &&b[i] > 0&& a[i] != b[i]){ 
                   return false; 
              }    
           } 
           return true; 
       } 
              
       /**
       * 将结果输出到一个文件中
       * 
       * @param fileName
       */ 
       public void printResult(String fileName) 
       {    
           FileWriter fw = null; 
                BufferedWriter bw = null; 
                try { 
                      fw = new FileWriter(fileName); 
                      bw = new BufferedWriter(fw); 
    	              //写入文件 
    	               for(int i = 0;i < InstanceNumber;i++) 
    	               { 
    	                  bw.write(String.valueOf(data[i][FieldCount]).substring(0, 1)); 
    	                   bw.newLine(); 
    	               } 
    	          
    	               //统计每类的数目,打印到控制台 
    	               for(int i = 0;i < ClassCount;i++) 
    	             { 
    	                     System.out.println("第" + (i+1) + "类数目: " + result.get(i).size()); 
    	             } 
                } catch (IOException e) { 
                	e.printStackTrace(); 
                } finally{ 
    	             //关闭资源 
    	             if(bw != null) 
    	                   try { 
    	                     bw.close(); 
    	                   } catch (IOException e) { 
    	                       e.printStackTrace(); 
    	                   } 
    	               if(fw != null) 
    	                    try { 
    	                        fw.close(); 
    	                    } catch (IOException e) { 
    	                         e.printStackTrace(); 
    	                    } 
                 }     
            } 
    }
    

      

      matlab的kmeans实现代码可直接参照其kmeans(X,k)函数的实现源码。

    附:K-means聚类算法

  • 相关阅读:
    usb mtp激活流程【转】
    [RK3288][Android6.0] USB OTG模式及切换【转】
    简单实用的磁带转MP3方法图解
    使用log4j的邮件功能
    hive从查询中获取数据插入到表或动态分区
    map和reduce 个数的设定 (Hive优化)经典
    Mysql ERROR 145 (HY000)
    Mysql计算时间差
    小米刷机教程和GAE for android
    Hbase 使用方法
  • 原文地址:https://www.cnblogs.com/fclbky/p/5026213.html
Copyright © 2011-2022 走看看