zoukankan      html  css  js  c++  java
  • K均值算法

    为了便于可视化,样本数据为随机生成的二维样本点。

    from matplotlib import pyplot as plt
    import numpy as np
    import random
    
    
    def kmeans(a, k):
        def randomChoose(a, k):
            # 从数组a中随机选取k个元素,返回一个list
            args = np.arange(len(a))  # 元素下标
            for i in range(k):
                x = np.random.randint(i, len(a))
                args[x], args[i] = args[i], args[x]  # 交换两个数
            return a[args[:k]]  # 返回前k个元素
    
        def tag(a, center):
            dis = np.empty((len(a), len(center)),dtype=np.float)
            for i in range(len(a)):
                for j in range(len(center)):
                    dis[i][j] = np.linalg.norm(a[i] - center[j])
            label = np.argmin(dis, axis=1)
            return label
    
        def get_center(a, label,k):
            centers = np.empty((k,a.shape[1]),dtype=np.float)
            for i in range(len(centers)):
                centers[i] = np.mean(a[label == i],axis=0)
            return centers
        centers = randomChoose(a, k)
        last_label = None
        label = tag(a, centers)
        while last_label is None or np.any(last_label != label):
            # print(centers)
            # input()
            last_label = label
            centers = get_center(a, label,k)
            label=tag(a,centers)
        return label,centers
    
    a = np.random.random((100, 2))
    print(a)
    c=['b', 'g', 'r', 'c', 'm', 'y', 'k', 'w']
    k=len(c)
    label,centers=kmeans(a,k)
    a=a.T
    fig,ax=plt.subplots(3,3)
    ax=ax.reshape(-1)
    #需要注意,如果不把全局画出来,看上去聚类效果很差,因为matplot自动缩放坐标轴
    for i in range(k):
        x,y=a[:,label==i]
        ax[i].scatter(x,y,c=c[i])
        ax[i].scatter(centers[:,0],centers[:,1],c='r')
    ax[-1].set_title("Centers")
    ax[-1].scatter(centers[:,0],centers[:,1])
    plt.show()
    

    K均值算法有很多可以变化的地方:

    • 在求新的聚类中心时,可以直接修改旧的聚类中心
      这样类似于迭代法求解线性方程组时的“高斯-赛德尔”迭代法。
      这样也可以节省一点点空间,不过没必要。
    • 点之间距离的计算
      可以用差向量的范数,也可以用余弦距离。

    K均值算法可以用于分类。
    首先指定聚类个数K,执行聚类算法得到K个聚类,给这K个聚类进行打标签(也就是进行投票,这相当于K近邻算法的投票阶段),预测时计算测试样本离哪个聚类最近,就表示该测试样本的类别。
    这样做的好处是,吸收了K近邻的优点,并且降低了时间复杂度(K近邻需要计算测试样本与N个训练样本之间的距离,K均值分类只需要计算测试样本与K个聚类中心之间的距离)。
    特别地,当聚类个数K=N的时候,K均值分类就变成了K近邻分类。

    下面分析一下KMeans的时空复杂度。
    N个样本,每个样本M个属性,聚类个数为K
    空间复杂度为O(K*M),只需要存储下来中心点即可
    时间复杂度为

    • 更新各点的label,复杂度为O(NKM),需要计算N*K次长度为M的向量模长
    • 重新求中心距离,复杂度为O(N*M),需要计算N次长度为M的向量之和

    所以,总的时间复杂度为O(NKM)

  • 相关阅读:
    百度指数感想
    冲刺贡献分
    冲刺三
    通过myEclipse创建hibernate的实体类
    并发处理
    数据库设计原则(转载)
    Extjs学习
    关于oracle存储过程需要注意的问题
    oracle存储过程
    十大编程算法
  • 原文地址:https://www.cnblogs.com/weiyinfu/p/7928968.html
Copyright © 2011-2022 走看看