zoukankan      html  css  js  c++  java
  • 机器学习中均值漂移

    import pandas as pd
    import numpy as np
    import matplotlib
    from matplotlib import pyplot as plt

    %matplotlib inline
    #指定默认字体
    matplotlib.rcParams['font.sans-serif'] = ['SimHei']
    #用来正常显示负号
    plt.rcParams['axes.unicode_minus']=False
    #读取文件夹
    data = pd.read_csv('010-data_multivar.csv',header=None)

    # 拆分数据
    dataset_X,dataset_y= data.iloc[:,:-1],data.iloc[:,-1]
    dataset_X = dataset_X.values
    dataset_y = dataset_y.values

    #估算带宽
    from sklearn.cluster import estimate_bandwidth,MeanShift
    #数据集 dataset_X样本 quantile:分位数 n_samples:使用样本大小
    bandwidth = estimate_bandwidth(dataset_X,quantile=0.1,n_samples=len(dataset_X))
    #format()把求出的数据传到{}里面
    print('带宽:{}'.format(bandwidth))

    #初始化聚类模型,带宽,网格化数据点(加速模型速度) bin_seeding:加速模型
    meanshift = MeanShift(bandwidth=bandwidth,bin_seeding=True)
    meanshift.fit(dataset_X)

    print(meanshift.cluster_centers_) #获取所有质心坐标
    print(meanshift.labels_) #获取所有标签

    数据可视化

    def visual_meanshift_effect(meanshift,dataset):
    assert dataset.shape[1]==2,'only support dataset with 2 features'
    X=dataset[:,0]
    Y=dataset[:,1]
    X_min,X_max=np.min(X)-1,np.max(X)+1
    Y_min,Y_max=np.min(Y)-1,np.max(Y)+1
    X_values,Y_values=np.meshgrid(np.arange(X_min,X_max,0.01),
    np.arange(Y_min,Y_max,0.01))
    # 预测网格点的标记
    predict_labels=meanshift.predict(np.c_[X_values.ravel(),Y_values.ravel()])
    predict_labels=predict_labels.reshape(X_values.shape)
    plt.figure()
    plt.imshow(predict_labels,interpolation='nearest',
    extent=(X_values.min(),X_values.max(),
    Y_values.min(),Y_values.max()),
    cmap=plt.cm.Paired,
    aspect='auto',
    origin='lower')

    # 将数据集绘制到图表中
    plt.scatter(X,Y,marker='v',facecolors='none',edgecolors='k',s=30)

    # 将中心点绘制到图中
    centroids=meanshift.cluster_centers_
    plt.scatter(centroids[:,0],centroids[:,1],marker='o',
    s=100,linewidths=2,color='k',zorder=5,facecolors='b')
    plt.title('MeanShift effect graph')
    plt.xlim(X_min,X_max)
    plt.ylim(Y_min,Y_max)
    plt.xlabel('feature_0')
    plt.ylabel('feature_1')
    plt.show()

    visual_meanshift_effect(meanshift,dataset_X)

  • 相关阅读:
    手动启动log4j|nginx实现http https共存
    java.util.zip.ZipException: invalid LOC header (bad signature)
    Bean property 'transactionManagerBeanName' is not writable or has an invalid set
    rabbitmq启动异常table_attributes_mismatch
    nexus私服快速update index方法
    Spring boot ,dubbo整合异常
    如何编写无须人工干预的shell脚本
    Jenkins构建部署jar/war后,服务无法在后台持续运行的解决方案
    移动端CSS通用样式
    Spring bean的几种装配方式
  • 原文地址:https://www.cnblogs.com/antique/p/10728169.html
Copyright © 2011-2022 走看看