zoukankan      html  css  js  c++  java
  • 聚类算法——Kmeans(下)

      K-means的源码实现

      一般情况下,我们可以用C++/Matlab/Python等语言实现K-means算法。结合我近期刚学的C++,先从C++实现谈起:C++里一般直接使用OpenCV库中写好的K-means函数,即cvKMeans2。首先来看函数原型:
      从OpenCV manual看到的是:
    int cvKMeans2(const CvArr* samples, int nclusters,
            CvArr* labels, CvTermCriteria termcrit,
            int attempts=1, CvRNG* rng=0,int flags=0,
            CvArr* centers=0,double* compactness=0);
    除去已经确定(带默认值)的参数,我们自己需要输入的参数为:
    void cvKMeans2( 
        const CvArr* samples, //输入样本的浮点矩阵,每个样本一行。 
        int cluster_count,  //所给定的聚类数目 
        CvArr* labels,    //输出整数向量:每个样本对应的类别标识 
         CvTermCriteria termcrit //指定聚类的最大迭代次数和/或精度(两次迭代引起的聚类中心的移动距离)
     ); 
    其使用例程为:

     1 #ifdef _CH_
    2 #pragma package <opencv>
    3 #endif
    4
    5 #define CV_NO_BACKWARD_COMPATIBILITY
    6
    7 #ifndef _EiC
    8 #include "cv.h"
    9 #include "highgui.h"
    10 #include <stdio.h>
    11 #endif
    12
    13 int main( int argc, char** argv )
    14 {
    15 #define MAX_CLUSTERS 5 //设置类别的颜色,个数(《=5
    16 CvScalar color_tab[MAX_CLUSTERS];
    17 IplImage* img = cvCreateImage( cvSize( 500, 500 ), 8, 3 );
    18 CvRNG rng = cvRNG(-1);
    19 CvPoint ipt;
    20
    21 color_tab[0] = CV_RGB(255,0,0);
    22 color_tab[1] = CV_RGB(0,255,0);
    23 color_tab[2] = CV_RGB(100,100,255);
    24 color_tab[3] = CV_RGB(255,0,255);
    25 color_tab[4] = CV_RGB(255,255,0);
    26
    27 cvNamedWindow( "clusters", 1 );
    28
    29 for(;;)
    30 {
    31 char key;
    32 int k, cluster_count = cvRandInt(&rng)%MAX_CLUSTERS + 1;
    33 int i, sample_count = cvRandInt(&rng)%1000 + 1;
    34 CvMat* points = cvCreateMat( sample_count, 1, CV_32FC2 );
    35 CvMat* clusters = cvCreateMat( sample_count, 1, CV_32SC1 );
    36 cluster_count = MIN(cluster_count, sample_count);
    37
    38 /** generate random sample from multigaussian distribution */
    39 for( k = 0; k < cluster_count; k++ )
    40 {
    41 CvPoint center;
    42 CvMat point_chunk;
    43 center.x = cvRandInt(&rng)%img->width;
    44 center.y = cvRandInt(&rng)%img->height;
    45 cvGetRows( points, &point_chunk, k*sample_count/cluster_count,
    46 k == cluster_count - 1 ? sample_count :
    47 (k+1)*sample_count/cluster_count, 1 );
    48
    49 cvRandArr( &rng, &point_chunk, CV_RAND_NORMAL,
    50 cvScalar(center.x,center.y,0,0),
    51 cvScalar(img->width*0.1,img->height*0.1,0,0));
    52 }
    53
    54 /** shuffle samples */
    55 for( i = 0; i < sample_count/2; i++ )
    56 {
    57 CvPoint2D32f* pt1 = (CvPoint2D32f*)points->data.fl + cvRandInt(&rng)%sample_count;
    58 CvPoint2D32f* pt2 = (CvPoint2D32f*)points->data.fl + cvRandInt(&rng)%sample_count;
    59 CvPoint2D32f temp;
    60 CV_SWAP( *pt1, *pt2, temp );
    61 }
    62
    63 printf( "iterations=%d\n", cvKMeans2( points, cluster_count, clusters,
    64 cvTermCriteria( CV_TERMCRIT_EPS+CV_TERMCRIT_ITER, 10, 1.0 ),
    65 5, 0, 0, 0, 0 ));
    66
    67 cvZero( img );
    68
    69 for( i = 0; i < sample_count; i++ )
    70 {
    71 int cluster_idx = clusters->data.i[i];
    72 ipt.x = (int)points->data.fl[i*2];
    73 ipt.y = (int)points->data.fl[i*2+1];
    74 cvCircle( img, ipt, 2, color_tab[cluster_idx], CV_FILLED, CV_AA, 0 );
    75 }
    76
    77 cvReleaseMat( &points );
    78 cvReleaseMat( &clusters );
    79
    80 cvShowImage( "clusters", img );
    81
    82 key = (char) cvWaitKey(0);
    83 if( key == 27 || key == 'q' || key == 'Q' ) // 'ESC'
    84 break;
    85 }
    86
    87 cvDestroyWindow( "clusters" );
    88 return 0;
    89 }
    90
    91 #ifdef _EiC
    92 main(1,"kmeans.c");
    93 #endif

      至于cvKMeans2函数的具体实现细节,可参见OpenCV源码。

      下面是Python的实现代码(网上所找):

     1  #!/usr/bin/python
    2
    3 from __future__ import with_statement
    4 import cPickle as pickle
    5 from matplotlib import pyplot
    6 from numpy import zeros, array, tile
    7 from scipy.linalg import norm
    8 import numpy.matlib as ml
    9 import random
    10
    def kmeans(X, k, observer=None, threshold=1e-15, maxiter=300):
        """Cluster the rows of ``X`` into ``k`` groups with Lloyd's K-means.

        Args:
            X: Array-like of shape (N, d), one sample per row. It is copied
                into a float array, so the caller's data is never modified.
            k: Number of clusters. Initial centers are ``k`` distinct rows of
                ``X`` chosen at random.
            observer: Optional callable ``observer(iteration, labels, centers)``
                invoked at the start of every iteration and once more after
                the loop ends. ``centers`` is updated in place between calls,
                so copy it if a snapshot is needed.
            threshold: Stop as soon as the objective (sum of squared distances
                from each sample to its assigned center) decreases by less
                than this amount in one iteration.
            maxiter: Maximum number of iterations.

        Returns:
            A tuple ``(labels, centers)``: ``labels`` holds the cluster index
            of every sample, ``centers`` is the ``(k, d)`` array of centers.

        Raises:
            ValueError: If ``k`` is negative or larger than the number of
                samples (raised by ``random.sample``).
        """
        X = array(X, dtype=float)
        N = len(X)
        labels = zeros(N, dtype=int)
        # Sample row indices rather than the array itself: random.sample only
        # accepts real sequences on recent Python versions, and fancy indexing
        # yields a fresh array that is safe to update in place.
        centers = X[random.sample(range(N), k)]
        iteration = 0

        def objective():
            # Within-cluster sum of squared distances. This is the quantity
            # K-means iterations never increase, which makes the "decrease
            # below threshold" stopping test below meaningful.
            residual = X - centers[labels]
            return (residual * residual).sum()

        def sq_distances(points, anchors):
            # ||p - a||^2 = ||p||^2 + ||a||^2 - 2 p.a for every pair (p, a),
            # returned as a (len(points), len(anchors)) matrix.
            pp = (points * points).sum(axis=1)
            aa = (anchors * anchors).sum(axis=1)
            return pp[:, None] + aa[None, :] - 2 * points.dot(anchors.T)

        prev_cost = objective()
        while True:
            # notify the observer
            if observer is not None:
                observer(iteration, labels, centers)

            # assign every sample to its nearest center
            labels = sq_distances(X, centers).argmin(axis=1)

            # re-calculate each center as the mean of its members
            for j in range(k):
                members = X[labels == j]
                # An empty cluster keeps its previous center; taking the mean
                # of zero rows would turn the center into NaN.
                if len(members):
                    centers[j] = members.mean(axis=0)

            cost = objective()
            iteration += 1

            if prev_cost - cost < threshold:
                break
            prev_cost = cost
            if iteration >= maxiter:
                break

        # final notification
        if observer is not None:
            observer(iteration, labels, centers)

        return labels, centers
    61
    if __name__ == '__main__':
        import os

        # Load previously generated points. Each entry of `samples` is indexed
        # as entry[0] (x coordinates) and entry[1] (y coordinates).
        # NOTE(review): unpickling can execute arbitrary code - only load a
        # cluster.pkl that comes from a trusted source.
        # Pickle data is binary, so the file is opened in 'rb' mode.
        with open('cluster.pkl', 'rb') as inf:
            samples = pickle.load(inf)

        # Stack all groups into a single (N, 2) array of points.
        N = sum(len(smp[0]) for smp in samples)
        X = zeros((N, 2))
        idxfrm = 0
        for smp in samples:
            idxto = idxfrm + len(smp[0])
            X[idxfrm:idxto, 0] = smp[0]
            X[idxfrm:idxto, 1] = smp[1]
            idxfrm = idxto

        # savefig does not create missing directories.
        if not os.path.isdir('kmeans'):
            os.makedirs('kmeans')

        def observer(iteration, labels, centers):
            """Save a scatter plot of the current clustering state as a PNG."""
            print("iter %d." % iteration)
            # one colour per cluster; matches the k=3 passed to kmeans below
            colors = array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
            # Clear the previous plot. pyplot.hold() and the `hold` keyword
            # were removed from matplotlib, so the figure is cleared directly.
            pyplot.clf()

            # draw points, each in the colour of its assigned cluster
            data_colors = colors[labels]
            pyplot.scatter(X[:, 0], X[:, 1], c=data_colors, alpha=0.5)
            # draw centers
            pyplot.scatter(centers[:, 0], centers[:, 1], s=200, c=colors)

            pyplot.savefig('kmeans/iter_%02d.png' % iteration, format='png')

        kmeans(X, 3, observer=observer)

      matlab的kmeans实现代码可直接参照其kmeans(X,k)函数的实现源码。

  • 相关阅读:
    分页参数处理逻辑的最佳实践
    浅谈软件界面设计原则
    Django 页面缓存的cache_key是如何生成的
    mvn 命令上传 jar 包到 nexus 私仓
    知 识 收 录
    JavaScript 使用Map对象
    windows bat脚本守护java进程
    ubuntu java启动shell脚本
    Linux cron定时任务启动jar程序
    ubuntu java调用海康sdk报错Unable to load library '/home/bjlthy/HCNetSDK/libPlayCtrl.so'
  • 原文地址:https://www.cnblogs.com/moondark/p/2385870.html
Copyright © 2011-2022 走看看