zoukankan      html  css  js  c++  java
  • Mean shift 【1】- 基本原理

    Mean shift 向量

    中文名叫 均值偏移向量,定义如下:

    在一个 n 维空间内,存在一个点 x,以该点为球心,以 h 为半径,生成一个球 Sh,计算球心到球内所有点生成的向量的均值,显然这个均值也是一个向量,这个向量就是 mean shift 向量

    公式如下

              【用 球心 计算 质心】

    图示如下 

     

    shift point 

    我给起个名字,叫 偏移点;

    注意,几乎没有资料专门提到这个概念,我为什么要讲呢?因为我们需要把 shift point 和 mean shift 区分开,这俩可不是一回事;

    而在我们写算法时,需要的是 shift point,而不是 mean shift

        def _shift_point(self, point, points, kernel_bandwidth):
            shift_x = 0.0
            shift_y = 0.0
            scale = 0.0
            for p in points:
                dist = distance(point, p)
                weight = self.kernel(dist, kernel_bandwidth)
                if dist == 0: print(weight)
                ### shift point
                shift_x += p[0] * weight
                shift_y += p[1] * weight
                ### 而不是 mean shift
                # shift_x += (p[0] - point[0]) * weight
                # shift_y += (p[1] - point[1]) * weight
                scale += weight
            shift_x = shift_x / scale
            shift_y = shift_y / scale
            return [shift_x, shift_y]

    那 shift point 到底是什么,又该如何计算呢,直接上图

    再来一张

     

    Mean shift 算法 

    基本思想

    对于 样本中的每一个点 x,做如下操作

    1. 以 x 为起点,计算他的 shift point  x‘,然后把 该点 “移动” 到 x’      【注意不是真的移动点,而是把 x 标记成 x’

    2. 以 x’ 为新起点,计算他的 shift point

    3. 重复 前两步,直至 前后两次 的 mean shift 向量满足条件,如 距离很近  【这一步才用到 mean shift,也就是 前后两个 shift point 相减得到 向量,再计算向量的模】

    4. 把 x 标记为 最终的 shift point,即为对应的类

    5. 遍历计算所有点

    过程大致如下图

    从上图可以看到,mean shift 向量指向了更密集的区域,也就是说 mean shift 算法是在寻找 最密集 的区域,作为最后的类别

    存在问题

    在计算 mean shift 向量时,圆圈内所有点的贡献是一样的 即1/k,而实际上离圆心越远可能贡献越小,

    为此 mean shift 算法引入核函数来表达这种贡献,代替 1/k

    引入核函数 

    核函数 参考 我的博客

    此处以 高斯核函数 为例

    其中 h 代表核函数的带宽(bandwidth)      【这个 h 和 高斯分布 里的 标准差σ 类似,但它不是 标准差,而是 人工指定的,但是起到的作用和 标准差一样】

    不同带宽的核函数表示如下

     

    在 h 一定时, x 离 均值(圆心)越远,函数值越小,体现到 mean shift 向量中,就是 贡献越小;    【高斯滤波不也是这样吗,那高斯滤波也可以引入核函数了】

    h 越小,衰减为 0 的速度就越快,也就是说 mean shift 向量对应的球 S越小,稍微远点就没有贡献了

    于是,引入 核函数 的 mean shift 向量变成如下样子,此时的 S可以为整个数据集(原因为上句)

    Mean shift VS KMeans

    1. KMeans 需要设置 k,mean shift 无需

    2. 实际工作中复杂数据用 mean shift 无法控制 k 个值,可能会产生过多的类而导致聚类失去意义

    示例代码

    import random
    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.datasets import make_blobs, make_moons
    
    
    # 定义 预先设定 的阈值
    STOP_THRESHOLD = 1e-4
    CLUSTER_THRESHOLD = 1e-1
    
    # 定义度量函数
    def distance(a, b):
        return np.linalg.norm(np.array(a) - np.array(b))
    
    # 定义高斯核函数
    def gaussian_kernel(distance, bandwidth):
        return (1 / (bandwidth * np.sqrt(2 * np.pi))) * np.exp(-0.5 * ((distance / bandwidth)) ** 2)
    
    
    # mean_shift类
    class mean_shift(object):
        def __init__(self, kernel=gaussian_kernel):
            self.kernel = kernel
    
        def fit(self, points, kernel_bandwidth):
    
            shift_points = np.array(points)     # 初始化偏移点
            shifting = [True] * points.shape[0] # 所有点都需要偏移
    
            ######## while 循环用来执行针对所有点的 多轮偏移
            while True:
                max_dist = 0
                ######## for 循环用来执行针对所有点的 一轮偏移
                for i in range(0, len(shift_points)):
                    if not shifting[i]:     # 是否需要偏移
                        continue
                    ##### 如果需要偏移,先移动一次,   while 循环保证一直会移动
                    p_shift_init = shift_points[i].copy()   # 获取一个点
                    ### 下面这句已经把原来的点 偏移到 偏移点了
                    shift_points[i] = self._shift_point(shift_points[i], points, kernel_bandwidth)
                    dist = distance(shift_points[i], p_shift_init)  # 偏移点 和 原来的点 的距离
    
                    max_dist = max(max_dist, dist)      # 取最大距离
                    # 距离大于停止条件,继续移动, 距离 小于 停止条件,结束移动
                    shifting[i] = dist > STOP_THRESHOLD 
                    ### 至此 一轮偏移 结束
                
                ### 每轮 偏移后 取 所有偏移向量得 最大值,
                ### 如果 小于 停止条件,说明所有点都 偏移到 最后的 点了,多伦偏移可以结束了
                if (max_dist < STOP_THRESHOLD):     
                    break
            
            ## shift_points 就是 每个点对应的 最终 shift point
            cluster_ids = self._cluster_points(shift_points.tolist())
            return shift_points, cluster_ids
    
        def _shift_point(self, point, points, kernel_bandwidth):
            # point 球心, points 球内点,计算 均值偏移向量
            shift_x = 0.0
            shift_y = 0.0
            scale = 0.0
            for p in points:
                print(point)
                dist = distance(point, p)
                weight = self.kernel(dist, kernel_bandwidth)
                # shift point 
                shift_x += p[0] * weight
                shift_y += p[1] * weight
                scale += weight
            shift_x = shift_x / scale
            shift_y = shift_y / scale
            return [shift_x, shift_y]
    
        def _cluster_points(self, points):
            cluster_ids = []
            cluster_idx = 0
            cluster_centers = []
    
            for i, point in enumerate(points):
                if (len(cluster_ids) == 0):
                    cluster_ids.append(cluster_idx)
                    cluster_centers.append(point)
                    cluster_idx += 1
                else:
                    for center in cluster_centers:
                        dist = distance(point, center)
                        if (dist < CLUSTER_THRESHOLD):
                            cluster_ids.append(cluster_centers.index(center))
                    if (len(cluster_ids) < i + 1):
                        cluster_ids.append(cluster_idx)
                        cluster_centers.append(point)
                        cluster_idx += 1
            return cluster_ids
    
    def colors(n):
        ret = []
        for i in range(n):
            ret.append((random.uniform(0, 1), random.uniform(0, 1), random.uniform(0, 1)))
        return ret
    
    def main():
        centers = [[-1, -1], [-1, 1], [1, -1], [1, 1]]
        X, _ = make_blobs(n_samples=300, centers=centers, cluster_std=0.4)  # h=0.5
        # X, y = make_moons(n_samples=200, noise=0.05, random_state=0)        # h
    
        mean_shifter = mean_shift()
        _, mean_shift_result = mean_shifter.fit(X, kernel_bandwidth=0.2)
    
        np.set_printoptions(precision=3)
        print('input: {}'.format(X))
        print('assined clusters: {}'.format(mean_shift_result))
        color = colors(np.unique(mean_shift_result).size)
    
        for i in range(len(mean_shift_result)):
            plt.scatter(X[i, 0], X[i, 1], color=color[mean_shift_result[i]])
        plt.show()
    
    
    if __name__ == '__main__':
        main()

    输出

    参考资料:

    https://zhuanlan.zhihu.com/p/81629406  机器学习-Mean Shift聚类算法

    https://www.biaodianfu.com/mean-shift.html  机器学习聚类算法之Mean Shift

    https://www.cnblogs.com/liqizhou/archive/2012/05/12/2497220.html  Meanshift,聚类算法

    https://blog.csdn.net/u014661698/article/details/84979979  聚类算法之meanshift

    https://blog.csdn.net/moge19/article/details/85346528 

    https://www.jb51.net/article/188375.htm  python实现mean-shift聚类算法

  • 相关阅读:
    【转】每天一个linux命令(41):ps命令
    【转】每天一个linux命令(40):wc命令
    【转】每天一个linux命令(39):grep 命令
    【转】每天一个linux命令(38):cal 命令
    【转】每天一个linux命令(37):date命令
    【转】每天一个linux命令(36):diff 命令
    【转】每天一个linux命令(35):ln 命令
    【转】每天一个linux命令(34):du 命令
    诗词、对联名句(千古名帖)
    诗词、对联名句(千古名帖)
  • 原文地址:https://www.cnblogs.com/yanshw/p/14931658.html
Copyright © 2011-2022 走看看