  • K-means Clustering in Machine Learning: Principles and a C Implementation

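    I used to work mainly on software for traditional audio, so the algorithms I dealt with were audio signal processing algorithms such as codecs and echo cancellation. I recently moved to speech recognition, where the algorithms are machine learning algorithms such as GMM. K-means is one of the simpler ones and is well worth mastering. This post covers the basic principles of K-means and a code implementation. The principles are described only briefly, because (1) K-means is fairly simple and (2) plenty of explanations are already available online; the emphasis is on the implementation.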

    1. Basic principles of K-means

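    K-means is an unsupervised clustering algorithm and one of the most widely used. (Clustering partitions the samples of a data set into several, usually disjoint, subsets; each subset is called a cluster.) K is the number of clusters and "means" refers to the mean. The main idea: given K and a set of samples (points), assign each sample to the cluster whose center is nearest; once all points are assigned, recompute each cluster's center as the mean of the points in that cluster; then repeat the assignment and update steps until the centers change very little or a specified number of iterations is reached.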
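    The same idea can be written as a formula. The symbols are notation assumed here, not taken from the text above: $C_k$ is the set of points in cluster $k$, $\mu_k$ is its center, and $x_i$ is a sample.

```latex
J = \sum_{k=1}^{K} \sum_{x_i \in C_k} \lVert x_i - \mu_k \rVert^2,
\qquad
\mu_k = \frac{1}{|C_k|} \sum_{x_i \in C_k} x_i
```

    With Euclidean distance, neither assigning each point to its nearest center nor replacing each center by the mean of its cluster can increase $J$, which is why the centers eventually stop changing.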

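    The K-means algorithm proceeds as follows:

    (a) Randomly choose K initial cluster centers.

    (b) Compute the distance from every sample to each of the K cluster centers.

    (c) If a sample is nearest to cluster center Ci, it belongs to cluster Ci; if it is equally distant from several cluster centers, it may be assigned to any one of them.

    (d) Once all samples have been assigned by distance, compute the mean of each cluster (the simplest way is to average each dimension of its samples) and use it as the new cluster center.

    (e) Repeat (b), (c) and (d) until the new cluster centers differ very little from those of the previous round, or until the specified number of iterations is reached; the algorithm then ends.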
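    Steps (b) to (d) can be sketched as a single function for 2-D points. This is a minimal sketch, not the implementation given in this article: Vec2, nearest_center, kmeans_step and MAX_K are names made up for the illustration, and leaving an empty cluster's center unchanged is an assumption of the sketch.

```c
#include <assert.h>
#include <float.h>
#include <stddef.h>

/* Hypothetical types and names for this sketch only. */
typedef struct { double x, y; } Vec2;

enum { MAX_K = 16 };   /* assumed upper bound on the number of clusters */

/* Steps (b)-(c): index of the center nearest to p (ties go to the lowest index). */
static size_t nearest_center(Vec2 p, const Vec2 *centers, size_t k)
{
    size_t best = 0;
    double best_d2 = DBL_MAX;
    for (size_t c = 0; c < k; c++) {
        double dx = p.x - centers[c].x;
        double dy = p.y - centers[c].y;
        double d2 = dx * dx + dy * dy;   /* squared distance is enough to compare */
        if (d2 < best_d2) {
            best_d2 = d2;
            best = c;
        }
    }
    return best;
}

/* One pass of steps (b)-(d). labels must have room for n entries.
   Returns 1 if any center moved, 0 if all centers stayed the same. */
static int kmeans_step(const Vec2 *pts, size_t n, Vec2 *centers, size_t k,
                       size_t *labels)
{
    Vec2 sum[MAX_K] = {{0}};
    size_t cnt[MAX_K] = {0};
    int moved = 0;

    assert(k >= 1 && k <= MAX_K);

    /* assign every point to its nearest center and accumulate per-cluster sums */
    for (size_t i = 0; i < n; i++) {
        labels[i] = nearest_center(pts[i], centers, k);
        sum[labels[i]].x += pts[i].x;
        sum[labels[i]].y += pts[i].y;
        cnt[labels[i]]++;
    }

    /* step (d): the mean of each non-empty cluster becomes its new center */
    for (size_t c = 0; c < k; c++) {
        if (cnt[c] == 0)
            continue;   /* assumption: an empty cluster keeps its old center */
        Vec2 m = { sum[c].x / cnt[c], sum[c].y / cnt[c] };
        if (m.x != centers[c].x || m.y != centers[c].y)
            moved = 1;
        centers[c] = m;
    }
    return moved;
}
```

    Step (e) is then a loop that calls kmeans_step until it returns 0 or a round limit is reached. Comparing squared distances avoids the sqrt call without changing which center is nearest.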

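    2. Implementation

    I mostly do low-level development and C is the language I know best, so the code is written in C. Consider a set of points on a two-dimensional plane.

    [Figure: example points scattered on a two-dimensional plane]

    K-means is used to cluster them; the number of clusters (the K value) and the number of points are specified by the user. The code is as follows: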

    #include<stdio.h>
    #include<stdlib.h>
    #include<string.h>
    #include<math.h>

    #define MAX_ROUNDS 100    //maximum number of clustering rounds

    //structure for a point
    typedef struct Point{
      int x_value;           //X coordinate of the point
      int y_value;           //Y coordinate of the point
      int cluster_id;        //id of the cluster the point belongs to
    }Point;
    Point* data;
     
    //structure for a cluster center
    typedef struct ClusterCenter{
      double x_value;
      double y_value;
      int cluster_id;
    }ClusterCenter;
    ClusterCenter* cluster_center;

    //structure used while computing the new cluster centers
    typedef struct CenterCalc{
      double x_value;
      double y_value;
    }CenterCalc;
    CenterCalc *center_calc;
     
    int is_continue;                   //whether k-means should keep iterating
    int* cluster_center_init_index;    //index of the point initially used as each cluster center
    double* distance_from_center;      //distances from one point to every cluster center
    int* data_size_per_cluster;        //number of points in each cluster
    int data_size_total;               //number of points
    char filename[200];                //name of the file holding the point data
    int cluster_count;                 //number of clusters
     
    void memoryAlloc();
    void memoryFree();
    void readDataFromFile();
    void initialCluster();
    void calcDistance2OneCenter(int pointID, int centerID);
    void calcDistance2AllCenters(int pointID);
    void partition4OnePoint(int pointID);
    void partition4AllPointOneCluster();
    void calcClusterCenter();
    void kmeans();
    void compareNewOldClusterCenter(CenterCalc* center_calc);
     
    int main(int argc, char* argv[])
    {
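        if( argc != 4 )
        {
            printf("This application needs 3 parameters to run:"
                "\n the 1st is the size of data set,"
                "\n the 2nd is the file name that contains data"
                "\n the 3rd indicates the cluster_count"
                "\n");
            exit(1);
        }

        data_size_total = atoi(argv[1]);
        //bounded copy: filename holds at most 199 characters
        snprintf(filename, sizeof(filename), "%s", argv[2]);
        cluster_count = atoi(argv[3]);
        //initialCluster() picks distinct points from indices 0..data_size_total-2,
        //so it needs 1 <= cluster_count < data_size_total
        if( cluster_count < 1 || cluster_count >= data_size_total )
        {
            printf("cluster_count must be at least 1 and smaller than the size of data set\n");
            exit(1);
        }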
        //1, memory alloc
        memoryAlloc();
        //2, read point data from file
        readDataFromFile();
        //3, initial cluster
        initialCluster();
        //4, run k-means
        kmeans();
        //5, memory free & end
        memoryFree();
        
        return 0;
    }

    void memoryAlloc()
    {
      data = (Point*)malloc(sizeof(struct Point) * (data_size_total));
      if( !data )
      {
        printf("malloc error: data!\n");
        exit(1);
      }
      cluster_center_init_index = (int*)malloc(sizeof(int) * (cluster_count));
      if( !cluster_center_init_index )
      {
        printf("malloc error: cluster_center_init_index!\n");
        exit(1);
      }
      distance_from_center = (double*)malloc(sizeof(double) * (cluster_count));
      if( !distance_from_center )
      {
        printf("malloc error: distance_from_center!\n");
        exit(1);
      }
      cluster_center = (ClusterCenter*)malloc(sizeof(struct ClusterCenter) * (cluster_count));
      if( !cluster_center )
      {
        printf("malloc error: cluster_center!\n");
        exit(1);
      }

      center_calc = (CenterCalc*)malloc(sizeof(CenterCalc) * cluster_count);
      if( !center_calc )
      {
        printf("malloc error: center_calc!\n");
        exit(1);
      }

      data_size_per_cluster = (int*)malloc(sizeof(int) * (cluster_count));
      if( !data_size_per_cluster )
      {
        printf("malloc error: data_size_per_cluster!\n");
        exit(1);
      }
     
    }

    void memoryFree()
    {
      free(data);
      data = NULL;
      free(cluster_center_init_index);
      cluster_center_init_index = NULL;
      free(distance_from_center);
      distance_from_center = NULL;
      free(cluster_center);
      cluster_center = NULL;
      free(center_calc);
      center_calc = NULL;
      free(data_size_per_cluster);
      data_size_per_cluster = NULL;
    }

    //read the x and y values of every point from the file
    void readDataFromFile()
    {
      int i;
      FILE* fp;    //not named fread, which would shadow the standard library function
     
      if( NULL == (fp = fopen(filename, "r")))
      {
        printf("open file(%s) error!\n", filename);
        exit(1);
      }

      for( i = 0; i < data_size_total; i++ )
      {
        if( 2 != fscanf(fp, "%d %d ", &data[i].x_value, &data[i].y_value))
        {
          printf("fscanf error: %d\n", i);
          fclose(fp);
          exit(1);    //the point would otherwise be left uninitialized
        }
        data[i].cluster_id = -1;    //every point starts with cluster id -1

        printf("After reading, point index:%d, X:%d, Y:%d, cluster_id:%d\n", i, data[i].x_value, data[i].y_value, data[i].cluster_id);
      }
      fclose(fp);
    }
     

    //randomly pick cluster_count points, one as the center of each cluster
    void initialCluster()
    {
      int i,j;
      int random;
        
      //mark the initial center index of every cluster as not yet chosen
      for( i = 0; i < cluster_count; i++ )
      {
        cluster_center_init_index[i] = -1;
      }
      //randomly pick a point as the center of each cluster (no repeats)
      for( i = 0; i < cluster_count; i++ )
      {
        Reselect:
            //rand() is never seeded, so every run makes the same choices;
            //the modulus data_size_total - 1 means the last point is never picked
            random = rand() % (data_size_total - 1);
            for(j = 0; j < i; j++) {
                if(random == cluster_center_init_index[j])
                    goto Reselect;
            }

        cluster_center_init_index[i] = random;
        printf("cluster_id: %d, located in point index:%d\n", i, random);  
      }
      //use the chosen points as centers; this also fixes the cluster id of those points
      for( i = 0; i < cluster_count; i++ )
      {
        cluster_center[i].x_value = data[cluster_center_init_index[i]].x_value;
        cluster_center[i].y_value = data[cluster_center_init_index[i]].y_value;
        cluster_center[i].cluster_id = i;
        data[cluster_center_init_index[i]].cluster_id = i;

        printf("cluster_id:%d, index:%d, x_value:%f, y_value:%f\n", cluster_center[i].cluster_id, cluster_center_init_index[i], cluster_center[i].x_value, cluster_center[i].y_value);
      }
    }
     

    //compute the Euclidean distance from one point to one cluster center
    void calcDistance2OneCenter(int point_id,int center_id)
    {
      double dx = data[point_id].x_value - cluster_center[center_id].x_value;
      double dy = data[point_id].y_value - cluster_center[center_id].y_value;
      distance_from_center[center_id] = sqrt(dx * dx + dy * dy);
    }
     
    //compute the distance from one point to every cluster center
    void calcDistance2AllCenters(int point_id)
    {
      int i;
      for( i = 0; i < cluster_count; i++ )
      {
        calcDistance2OneCenter(point_id, i);
      }
    }
     
    //decide which cluster a point belongs to (the one at minimum distance)
    void partition4OnePoint(int point_id)
    {
      int i;
      int min_index = 0;
      double min_value = distance_from_center[0];
      for( i = 0; i < cluster_count; i++ )
      {
        if( distance_from_center[i] < min_value )
        {
          min_value = distance_from_center[i];
          min_index = i;
        }
      }
     
      data[point_id].cluster_id = cluster_center[min_index].cluster_id;
    }

    //in one round, determine the cluster of every point
    void partition4AllPointOneCluster()
    {
      int i;
      for( i = 0; i < data_size_total; i++ )
      {
        if( data[i].cluster_id != -1 )  //this point is an initial center and already has its cluster id
          continue;
        else
        {
          calcDistance2AllCenters(i);  //distances from point i to all centers
          partition4OnePoint(i);       //assign point i according to those distances
        }
      }
    }

    //recompute the cluster centers
    void calcClusterCenter()
    {
      int i;

      memset(center_calc, 0, sizeof(CenterCalc) * cluster_count);
      memset(data_size_per_cluster, 0, sizeof(int) * cluster_count);
      //sum X and Y over the points of each cluster and count the points per cluster
      for( i = 0; i < data_size_total; i++ )
      {
        center_calc[data[i].cluster_id].x_value += data[i].x_value;
        center_calc[data[i].cluster_id].y_value += data[i].y_value;
        data_size_per_cluster[data[i].cluster_id]++;
      }
      //the mean of X and Y over the points of each cluster becomes its center
      for( i = 0; i < cluster_count; i++ )
      {
        if(data_size_per_cluster[i] != 0) {
            center_calc[i].x_value = center_calc[i].x_value/ (double)(data_size_per_cluster[i]);
            center_calc[i].y_value = center_calc[i].y_value/ (double)(data_size_per_cluster[i]);

            printf(" cluster %d point cnt:%d\n", i, data_size_per_cluster[i]);
            printf(" cluster %d center: X:%f, Y:%f\n", i, center_calc[i].x_value, center_calc[i].y_value);
        }
        else {
            //an empty cluster keeps its previous center instead of jumping to (0, 0)
            center_calc[i].x_value = cluster_center[i].x_value;
            center_calc[i].y_value = cluster_center[i].y_value;
            printf(" cluster %d count is zero\n", i);
        }
      }
     
      //compare the new and old cluster centers; if they are equal, K-means stops
      compareNewOldClusterCenter(center_calc);
     
      //store the new cluster centers in cluster_center
      for( i = 0; i < cluster_count; i++ )
      {
        cluster_center[i].x_value = center_calc[i].x_value;
        cluster_center[i].y_value = center_calc[i].y_value;
        cluster_center[i].cluster_id = i;
      }

      //with new centers computed, every point must be reassigned, so reset every cluster_id in data to -1
      for( i = 0; i < data_size_total; i++ )
      {
        data[i].cluster_id = -1;
      }
    }
     
    //compare the new and old cluster centers; identical values mean clustering is done
    void compareNewOldClusterCenter(CenterCalc* center_calc)
    {
      int i;
      is_continue = 0;       //0 means stop, 1 means continue
      for( i = 0; i < cluster_count; i++ )
      {
        if( center_calc[i].x_value != cluster_center[i].x_value || center_calc[i].y_value != cluster_center[i].y_value)
        {
          is_continue = 1;
          break;
        }
      }
    }
     
    //the K-means algorithm
    void kmeans()
    {
      int rounds;
      for( rounds = 0; rounds < MAX_ROUNDS; rounds++ )
      {
        printf("\nRounds : %d             \n", rounds+1);
        partition4AllPointOneCluster();
        calcClusterCenter();
        if( 0 == is_continue )
        {
           printf("\n after %d rounds, the classification is ok and can stop.\n", rounds+1);
           break;  
        }
      }
    }

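    Compiling produces the executable kmeans. The input file contains 6 points, (0, 0), (4, 4), (4, 5), (0, 1), (3, 6) and (4, 9), to be split into two clusters. Running the program gives the following output: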

    $ ./kmeans 6 data 2
    After reading, point index:0, X:0, Y:0, cluster_id:-1
    After reading, point index:1, X:4, Y:4, cluster_id:-1
    After reading, point index:2, X:4, Y:5, cluster_id:-1
    After reading, point index:3, X:0, Y:1, cluster_id:-1
    After reading, point index:4, X:3, Y:6, cluster_id:-1
    After reading, point index:5, X:4, Y:9, cluster_id:-1


    cluster_id: 0, located in point index:3
    cluster_id: 1, located in point index:1
    cluster_id:0, index:3, x_value:0.000000, y_value:1.000000
    cluster_id:1, index:1, x_value:4.000000, y_value:4.000000

    Rounds : 1             
     cluster 0 point cnt:2
     cluster 0 center: X:0.000000, Y:0.500000
     cluster 1 point cnt:4
     cluster 1 center: X:3.750000, Y:6.000000

    Rounds : 2             
     cluster 0 point cnt:2
     cluster 0 center: X:0.000000, Y:0.500000
     cluster 1 point cnt:4
     cluster 1 center: X:3.750000, Y:6.000000

     after 2 rounds, the classification is ok and can stop.


    So the clustering converges after two rounds: (0, 0) and (0, 1) form one cluster, and (4, 4), (4, 5), (3, 6) and (4, 9) form the other.
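    The output above lists only the centers and the point counts, not which points fall in which cluster. The membership can be checked by labelling each point with its nearest reported center. The helper below is a minimal sketch for that check; closest_center and its parameters are hypothetical names that do not appear in the program above.

```c
#include <assert.h>

/* Hypothetical helper for this sketch: index of the center nearest to (x, y).
   Centers are passed as parallel arrays cx[] and cy[] of length k. */
static int closest_center(int x, int y, const double *cx, const double *cy, int k)
{
    int best = 0;
    double best_d2 = 0.0;

    assert(k >= 1);
    for (int c = 0; c < k; c++) {
        double dx = x - cx[c];
        double dy = y - cy[c];
        double d2 = dx * dx + dy * dy;   /* squared Euclidean distance */
        if (c == 0 || d2 < best_d2) {
            best_d2 = d2;
            best = c;
        }
    }
    return best;
}
```

    With the centers from the output, cx = {0.0, 3.75} and cy = {0.5, 6.0}, the points (0, 0) and (0, 1) map to index 0 and the other four points map to index 1.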

  • Original article: https://www.cnblogs.com/talkaudiodev/p/10901798.html