zoukankan      html  css  js  c++  java
  • kmeans算法python实现

    import os
    import sys
    import numpy as np
    import matplotlib.pyplot as plt
    import random
    
    # https://www.jianshu.com/p/c31b08179655
    
    def load_data(path="data.txt"):
        """Load 2-D sample points from a whitespace-separated text file.

        Every line is split on whitespace and its first two tokens are parsed
        as floats; any further tokens on the line are ignored.  Lines that
        are blank, have fewer than two tokens, or whose first two tokens are
        not numeric are skipped.

        Args:
            path: Path of the UTF-8 text file to read.

        Returns:
            A NumPy array with one ``[x, y]`` row per parsed line.  When no
            line could be parsed the result is an empty array of shape
            ``(0,)``.

        Raises:
            OSError: If the file cannot be opened.
        """
        data = []
        # The context manager closes the file even if reading fails midway.
        with open(path, encoding='utf-8') as f:
            for line in f:
                fields = line.split()
                try:
                    data.append([float(fields[0]), float(fields[1])])
                except (IndexError, ValueError):
                    # Only malformed rows are skipped; anything else
                    # (e.g. KeyboardInterrupt) propagates to the caller.
                    continue

        return np.array(data)
    
    def showCluster(dataset, k, centroids, cluster_assignment):
        """Plot the samples coloured by cluster together with the centroids.

        Args:
            dataset: 2-D array; columns 0 and 1 are used as x and y.
            k: Number of centroids to draw.
            centroids: 2-D array; row ``i`` holds the coordinates of
                centroid ``i``.
            cluster_assignment: 2-D array whose column 0 holds the cluster
                index of each sample (truncated to an integer).

        Marker styles are reused cyclically, so more clusters than
        available styles no longer raises ``IndexError``.
        """
        sample_marks = ['or', 'ob', 'og', 'ok', '^r', '+r', 'sr', 'dr']
        centroid_marks = ['Dr', 'Db', 'Dg', 'Dk', '^b', '+b', 'sb', 'db']

        # Draw each cluster with a single plot call instead of one call per
        # sample; the format strings contain no line style, so only markers
        # are drawn.
        labels = cluster_assignment[:, 0].astype(int)
        for label in np.unique(labels):
            members = dataset[labels == label]
            style = sample_marks[label % len(sample_marks)]
            plt.plot(members[:, 0], members[:, 1], style)

        # Centroids are drawn last and larger so they stay visible.
        for i in range(k):
            style = centroid_marks[i % len(centroid_marks)]
            plt.plot(centroids[i, 0], centroids[i, 1], style, markersize=12)
        plt.show()
    
    
    def cal_dist(v1, v2):
        """Return the Euclidean distance between ``v1`` and ``v2``."""
        delta = v1 - v2
        return np.sqrt(np.sum(delta * delta))
    
    def select_init_centroids(dataset, k):
        """Pick ``k`` distinct rows of ``dataset`` at random as initial centroids.

        Args:
            dataset: 2-D array with one sample per row.
            k: Number of centroids to pick.

        Returns:
            A new float array of shape ``(k, dim)`` whose rows are copies of
            ``k`` different rows of ``dataset``.

        Raises:
            ValueError: If ``k`` is negative or larger than the number of
                samples, since that many distinct rows cannot be drawn.
        """
        samples_num, dim = dataset.shape
        if not 0 <= k <= samples_num:
            raise ValueError(
                "k must be between 0 and the number of samples "
                "(%d), got %d" % (samples_num, k)
            )

        centroids = np.zeros((k, dim))
        # random.sample draws without replacement, so no retry loop or
        # duplicate bookkeeping is needed.
        for row, idx in enumerate(random.sample(range(samples_num), k)):
            centroids[row] = dataset[idx]
        return centroids
    
    def kmeans(dataset, k):
        """Cluster the rows of ``dataset`` into ``k`` groups with k-means.

        Starting from centroids chosen by ``select_init_centroids``, the
        function alternates between assigning every sample to its nearest
        centroid and moving each centroid to the mean of its samples, until
        no sample changes cluster.

        Args:
            dataset: 2-D array with one sample per row.
            k: Number of clusters.

        Returns:
            A tuple ``(centroids, cluster_assignment)``.  ``centroids`` has
            shape ``(k, dim)``.  ``cluster_assignment`` has shape
            ``(samples_num, 2)``: column 0 is the index of the sample's
            cluster and column 1 is the distance (as computed by
            ``cal_dist``) from the sample to that cluster's centroid.
        """
        samples_num, dim = dataset.shape
        centroids = select_init_centroids(dataset, k)
        cluster_assignment = np.zeros((samples_num, 2))
        # -1 marks "not assigned yet".  Starting from 0 would make samples
        # whose nearest centroid is index 0 look unchanged on the first pass.
        cluster_assignment[:, 0] = -1
        cluster_changed = True
        while cluster_changed:
            cluster_changed = False
            for i in range(samples_num):
                # Infinity instead of a fixed large number, so samples that
                # are very far from every centroid are still assigned
                # correctly.
                min_dist = np.inf
                min_idx = 0

                # Find the closest centroid.
                for j in range(k):
                    dist = cal_dist(dataset[i], centroids[j])
                    if dist < min_dist:
                        min_dist = dist
                        min_idx = j

                # Record a change of cluster.
                if cluster_assignment[i, 0] != min_idx:
                    cluster_assignment[i, 0] = min_idx
                    cluster_changed = True
                # Centroids move between passes, so the stored distance is
                # refreshed even when the cluster index stays the same.
                cluster_assignment[i, 1] = min_dist

            # Move every centroid to the mean of its members.
            for j in range(k):
                points_in_cluster = dataset[cluster_assignment[:, 0] == j]
                # An empty cluster keeps its previous centroid; the mean of
                # an empty selection would be NaN.
                if len(points_in_cluster) > 0:
                    centroids[j, :] = np.mean(points_in_cluster, axis=0)

        return centroids, cluster_assignment
    
    def main():
        """Cluster the default data file into four groups, print and plot them."""
        # Single source of truth for the cluster count used by both the
        # clustering call and the plot.
        k = 4
        dataset = load_data()
        centroids, cluster_assignment = kmeans(dataset, k)
        print(centroids)
        showCluster(dataset, k, centroids, cluster_assignment)


    # Run only when executed as a script, so importing this module does not
    # read the data file or open a plot window.
    if __name__ == "__main__":
        main()

    最终效果如下:

    附数据(保存为 data.txt,每行两个以空白分隔的数值):

    1.658985    4.285136  
    -3.453687   3.424321  
    4.838138    -1.151539  
    -5.379713   -3.362104  
    0.972564    2.924086  
    -3.567919   1.531611  
    0.450614    -3.302219  
    -3.487105   -1.724432  
    2.668759    1.594842  
    -3.156485   3.191137  
    3.165506    -3.999838  
    -2.786837   -3.099354  
    4.208187    2.984927  
    -2.123337   2.943366  
    0.704199    -0.479481  
    -0.392370   -3.963704  
    2.831667    1.574018  
    -0.790153   3.343144  
    2.943496    -3.357075  
    -3.195883   -2.283926  
    2.336445    2.875106  
    -1.786345   2.554248  
    2.190101    -1.906020  
    -3.403367   -2.778288  
    1.778124    3.880832  
    -1.688346   2.230267  
    2.592976    -2.054368  
    -4.007257   -3.207066  
    2.257734    3.387564  
    -2.679011   0.785119  
    0.939512    -4.023563  
    -3.674424   -2.261084  
    2.046259    2.735279  
    -3.189470   1.780269  
    4.372646    -0.822248  
    -2.579316   -3.497576  
    1.889034    5.190400  
    -0.798747   2.185588  
    2.836520    -2.658556  
    -3.837877   -3.253815  
    2.096701    3.886007  
    -2.709034   2.923887  
    3.367037    -3.184789  
    -2.121479   -4.232586  
    2.329546    3.179764  
    -3.284816   3.273099  
    3.091414    -3.815232  
    -3.762093   -2.432191  
    3.542056    2.778832  
    -1.736822   4.241041  
    2.127073    -2.983680  
    -4.323818   -3.938116  
    3.792121    5.135768  
    -4.786473   3.358547  
    2.624081    -3.260715  
    -4.009299   -2.978115  
    2.493525    1.963710  
    -2.513661   2.642162  
    1.864375    -3.176309  
    -3.171184   -3.572452  
    2.894220    2.489128  
    -2.562539   2.884438  
    3.491078    -3.947487  
    -2.565729   -2.012114  
    3.332948    3.983102  
    -1.616805   3.573188  
    2.280615    -2.559444  
    -2.651229   -3.103198  
    2.321395    3.154987  
    -1.685703   2.939697  
    3.031012    -3.620252  
    -4.599622   -2.185829  
    4.196223    1.126677  
    -2.133863   3.093686  
    4.668892    -2.562705  
    -2.793241   -2.149706  
    2.884105    3.043438  
    -2.967647   2.848696  
    4.479332    -1.764772  
    -4.905566   -2.911070 
  • 相关阅读:
    分别使用Nginx反向代理和Haproxy调度器实现web服务器负载均衡
    CentOS7.4 源码编译安装LNMP
    LVS-DR+keepalived高可用群集
    Weex 和 Web 平台的差异
    Weex 和 Vue 2.x 的语法差异
    如何将原有 Weex 项目改造成 Vue 版本
    Vue 2.x 在 Weex 和 Web 中的差异
    使用 Vuex 和 vue-router
    使用 Vue 开发 Weex 页面
    weex快速上手
  • 原文地址:https://www.cnblogs.com/hejunlin1992/p/12887323.html
Copyright © 2011-2022 走看看