zoukankan      html  css  js  c++  java
  • Meanshift均值漂移算法

     
     
    通俗理解Meanshift均值漂移算法 
    Meanshift车手?? 漂移?? 秋名山???   不,不,他是一组算法,  今天我就带大家来了解一下机器学习中的Meanshift均值漂移.
    Meanshift算法他的本质是一个迭代的过程 , 我先给大家讲一下他的底层原理
     
     
    1)概述
    Mean-shift(均值迁移)的基本思想:在数据集中选定一个点,然后以这个点为圆心,r为半径,画一个圆(二维下是圆),求出这个点到所有点的向量的平均值,而圆心与向量均值的和为新的圆心,然后迭代此过程,直到满足一点的条件结束。
    后来在此基础上加入了 核函数 和 权重系数 ,使得Mean-shift 算法开始流行起来。目前它在聚类、图像平滑、分割、跟踪等方面有着广泛的应用。
     
    2) 图解过程
    为了方便大家理解,借用下几张图来说明Mean-shift的基本过程。
    第一张图有一个子中心点,她向四周最近的点开始寻找,找到圆心与向量均值的和为新的圆心,然后依次循环,直到满足条件,则不会再寻找其他圆心点
    3)Mean-shift 算法函数
    a)核心函数:sklearn.cluster.MeanShift(核函数:RBF核函数)
    由上图可知,圆心(或种子)的确定和半径(或带宽)的选择,是影响算法效率的两个主要因素。所以在sklearn.cluster.MeanShift中重点说明了这两个参数的设定问题。
    b)主要参数
    bandwidth :半径(或带宽),float型。如果没有给出,则使用sklearn.cluster.estimate_bandwidth计算出半径(带宽).(可选)
    seeds :圆心(或种子),数组类型,即初始化的圆心。(可选)
    bin_seeding :布尔值。如果为真,初始内核位置不是所有点的位置,而是点的离散版本的位置,其中点被分类到其粗糙度对应于带宽的网格上。将此选项设置为True将加速算法,因为较少的种子将被初始化。默认值:False.如果种子参数(seeds)不为None则忽略。
    c)主要属性
    cluster_centers_ : 数组类型。计算出的聚类中心的坐标。
    labels_ :数组类型。每个数据点的分类标签。
     
    4)代码详解  这里用到的是一组贝叶斯数据
     
    #分割数据集,拆分数据

    #坐标轴负一问题
    plt.rcParams['axes.unicode_minus'] =False
    #分割数据集
    from sklearn.model_selection import train_test_split
    data=pd.read_csv('./贝叶斯.csv',header=None)
    print(data.shape) #显示几行几列

    #拆分数据
    dataset_X,dataset_y =data.iloc[:,:-1],data.iloc[:,-1]
    # print(dataset_X.head())

    ## 将pandas转为np.ndarray 可以用dataset = df.as_matrix()
    dataset_X =dataset_X.values
    dataset_y =dataset_y.values

    #估算带宽
    from sklearn.cluster import estimate_bandwidth,MeanShift
    # estimate_bandwidth有估计带宽的意思 n_clusters聚类的个数 quantile分位数,分位点
    bandwidth = estimate_bandwidth(dataset_X,quantile=0.1,n_samples=len(dataset_X))
    #打印出带宽
    print(bandwidth).

    #初始化聚类模型 band带宽 bin_seeding网格化数据点(加速模型)
    meanshift = MeanShift(bandwidth=bandwidth,bin_seeding=True)
    # 训练模型
    meanshift.fit(dataset_X)
    print(meanshift.cluster_centers_)
    print(meanshift.labels_)

    此时打印除掉数据如下,

    #最后一步,将图形绘制出,查看一下效果

    def visual_meanshift_effect(meanshift,dataset):
    assert dataset.shape[1]==2,'only support dataset with 2 features'
    X=dataset[:,0]
    Y=dataset[:,1]
    X_min,X_max=np.min(X)-1,np.max(X)+1
    Y_min,Y_max=np.min(Y)-1,np.max(Y)+1
    X_values,Y_values=np.meshgrid(np.arange(X_min,X_max,0.01),
    np.arange(Y_min,Y_max,0.01))
    # 预测网格点的标记
    predict_labels=meanshift.predict(np.c_[X_values.ravel(),Y_values.ravel()])
    predict_labels=predict_labels.reshape(X_values.shape)
    plt.figure()
    plt.imshow(predict_labels,interpolation='nearest',
    extent=(X_values.min(),X_values.max(),
    Y_values.min(),Y_values.max()),
    cmap=plt.cm.Paired,
    aspect='auto',
    origin='lower')

    # 将数据集绘制到图表中
    plt.scatter(X,Y,marker='v',facecolors='none',edgecolors='k',s=30)

    # 将中心点绘制到图中
    centroids=meanshift.cluster_centers_
    plt.scatter(centroids[:,0],centroids[:,1],marker='o',
    s=100,linewidths=2,color='k',zorder=5,facecolors='b')
    plt.title('MeanShift effect graph')
    plt.xlim(X_min,X_max)
    plt.ylim(Y_min,Y_max)
    plt.xlabel('feature_0')
    plt.ylabel('feature_1')
    plt.show()
    visual_meanshift_effect(meanshift,dataset_X)

     
     
     
     
     
  • 相关阅读:
    使用百度网盘配置私有Git服务
    Linked dylibs built for GC-only but object files built for retain/release for architecture x86_64
    我的博客搬家啦!!!
    今日头条核心业务(高级)开发工程师,直接推给部门经理,HC很多,感兴趣的可以一起聊聊。
    学习Python的三种境界
    拿到阿里,网易游戏,腾讯,smartx的offer的过程
    关于计算机网络一些问题的思考
    网易游戏面试经验(三)
    网易游戏面试经验(二)
    网易游戏面试经验(一)
  • 原文地址:https://www.cnblogs.com/lowbi/p/10733733.html
Copyright © 2011-2022 走看看