zoukankan      html  css  js  c++  java
  • K-Means 算法

    认识

    K-Means 是属于聚类算法中的一种, 聚类算法呢, 是属于 无监督学习. 不需要数据的标签(label). 主要用途是为了发现数据中的规律(模式), 就咱平时说的数据挖掘. 使用的场景, 从营销领域来看, 可用来做用户细分, 行为聚类, 精准营销, 推荐系统等, 当然在其他领域,诸如图像分割, 生物研究也可以的.

    无监督: 就是只计算特征之间的关系不关注label. 场景的,如 聚类, PCA...

    一言蔽之, 物以类聚, 人以群分. 核心是把相似的物体聚在一起, 处理上则是计算物体间的相似度.

    K-Means 算法

    是一种循环迭代式的算法.

    初始化

    • 随机选择 K 个点, 作为初始点的中心, 每个点作为一个 group.

    交替更新

    • 计算每个点到中心点的距离, 把最近的距离记录并将group赋给当前的点
    • 针对每一个group的点, 计算其平均并作为这个group的新中心点

    循环上两步即可, 从代码的角度讲, 停止循环有两个注意点

    • 外层循环,设置一个最大循环次数, 比如100次.
    • 计算中心点, 前后两次位置不变 (完全分开了呀)

    伪代码

    为了说明意思, 搬的伪代码哈, 有空再好好写一波.

    def dist_eclud(arr1, arr2):
        """计算两个点之间的欧式距离"""
        return np.sqrt(np.sum((arr1 - arr2)**2))
    
    def random_center(X, k): 
        """随机初始化k个中心点,X 是(n, p)维的样本"""
        # 获取样本矩阵的行列数
        n, p = np.shape(X)
        #创建 k 行 p 列的全为0 的矩阵
        center = np.matrix(np.zeros([k, p]))  
        for j in range(p):
            min_j = np.min(X[:,j])
             # 计算极值
            range_j = float(np.max(X[:,j]) - min_j)    
            # 得到 k 个随机中心点, 每个点是p维度
            center[:,j] = np.mat(min_j + range_j * np.random.rand(k, 1))   
        return center
    
    def fit(X, k):
        # 获取样本数(行数)
        n,p = np.shape(X)
        # 创建一个n行p列的矩阵
        cluster = np.mat(np.zeros([n,p]))
        # 随机初始化k个中心点,每个点是p维
        center = random_center(X, k)
        
        while True:
            for i in range(n):
                # 初始化最小距离 (默认非常大)
                min_dist = np.inf
                # 中心点下标
                min_index = 0
                # 计算k个中心点到该样本的距离
                for j in range(k):
                    dist_j = dist_eclud(center[j,:], X[i,:])
                    # 更新最小距离
                    if dist_j < min_dist:
                        min_dist = dist_j
                        min_index = j 
                # 判断: 如果中心点前后没有变化,说明不能再继续分, 则终止循环
                if cluster[i, 0] != min_index:
                    break
                # 将 index, k中心点, 最小距离存到数组
                cluster[i, :] = min_index, min_dist**2
            # 打印中心点
            print(center)
            
            # 更新中心点的位置
            for i in range(k):
                # 平铺cluster 为一个一维数组, 分别找到属于k类的数据
                points = dataSet[np.nonzero(cluster[:,0].A == i)[0]] 
                #得到更新后的中心点
                center[i,:] = np.mean(points, axis = 0)  
        return center, cluster 
                
    

    K-Means 算法特性

    当然是考量每一次迭代的复杂度了.

    第一步的选取 k 个点, 这里不考虑了, 关键是计算和迭代的部分.

    首先,计算每个点到中心点的距离, 把最近的距离记录下来并把group赋给当前点. 假设有 n个样本点.则此过程的时间复杂度为:

    (O(kn))

    因为有 k 个中心点. 每次计要把 n 个点 计算到 k 个点的距离

    然后, 针对每一个 group 里的点, 计算平均作为这个 group 的新中心点. 则此过程的时间复杂度为:

    (O(n))

    K-means 的几点思考

    目标函数

    已知观测集 ((x_1, x_2, ... x_n)) 其中每个观测 xi 都是 d 维的实向量. k-平均聚类, 就是要把 这n个观测划分到 k 个集合中 ( k <= n) , 使得簇内平方和最小. 即找到使得下式满足的聚类 (S_i)

    (arg min _S sum limits _{i=1}^k sum limits _{x in S_i} ||X-mu_i||^2)

    (mu_i)(S_i) 类中的所有点的均值

    (sum limits _{x in S_i} ||X-mu_i||^2) 表示类别为 i 的样本点, 到中心的距离尽可能小

    外层的求和, 是对每个类到其自身中心的距离, 总体上达到最小

    关键: 是如何找到这些 xi

    是否一定会收敛

    是的

    首先, 将N个数据分为 k 个聚类, 最多有 (k^N) 种可能, 这是一个有限的数值. 对于算法迭代, 我们仅基于旧的聚类产生一个新的聚类.

    • 如果旧的聚类和新的聚类相同, 则下一次迭代的结果也再次相同
    • 新,旧不同, 则更新的群集的目标函数值较低

    其次, 算法的不同迭代, 最终会进入一个循环, 即k均值在有限的迭代次数中收敛

    不同初始化对结果的影响

    不同初始点,会带来不同的结果

    K 如何选择

    inertia: 群集中,所有点离其所属cluster中心的距离的总和.

    不论数据集如何, 随着k的增加, inertia 呈递减趋势, 开始速度很快, 慢慢变小, 会存在一个拐点.

    极限情况是, 每一个点都各自一类, 则总距离为0了, 也就没有了所谓 "聚类''的存在

    KMeans ++

    • 从数据集中随机选择一个中心
    • 对于每个 xi, 计算与**已经选择的最接近中心之间的距离 **D(xi)
    • 使用加权概率分布随机选择一个新的数据点作为新中心, 该点的概率与 (D(x_i)^2) 成正比
    • 重复步骤 2,3 直到选择了 k 个中心
    • 根据选择好了的初始中心, 继续使用k-means 聚类

    这种算法, 其实就是在选择初始中心的时候, 尽可能 "分布在数据四周" , 后面不变.

    KMeans 优缺点

    优点

    • 容易理解和实现
    • 结果的解释性强
    • 计算复杂度低

    缺点

    • 异常值 和 初始化 非常敏感

    • 对数据的 分布不均匀 条件下, 效果不太理想 (样本不均衡就有点麻烦)

    • 默认前提假设是特征之间的联合分布是椭圆形的, S形, 环形就凉凉了 ( svm 线性不可分)

    • 即便 k 给定, 聚类的结果也不唯一. 最好要多调试几次, 选择 inertia 最小时的参数

  • 相关阅读:
    《MySQL必知必会》第二十三章:使用存储过程
    《MySQL必知必会》第二十四章:使用游标
    《MySQL必知必会》第二十五章:使用触发器
    《MySQL必知必会》第二十章:更新和删除数据
    《MySQL必知必会》第二十一章:创建和操纵表
    《MySQL必知必会》第二十二章:使用视图
    《MySQL必知必会》第十七章:组合查询
    《MySQL必知必会》第十八章:全文本搜索
    [LeetCode] 930. Binary Subarrays With Sum
    [LeetCode] 676. Implement Magic Dictionary
  • 原文地址:https://www.cnblogs.com/chenjieyouge/p/12041346.html
Copyright © 2011-2022 走看看