zoukankan      html  css  js  c++  java
  • Knn和meanShift学习——注释就是笔记

    Knn是需要监督的学习,也需要规定有几个特征

    meanShift是不需要监督,也不需要规定类别的,他是通过算法自动位移归类的

    # 加载数据
    import pandas as pd
    import numpy as np
    from matplotlib import pyplot as plt
    from sklearn.cluster import KMeans
    from sklearn.metrics import accuracy_score
    from sklearn.neighbors import  KNeighborsClassifier
    if __name__ == '__main__':
    
        data = pd.read_csv('data/data2.csv')
        data.head()
        # define X and y
        X = data.drop(['labels'], axis=1)      # axis意思是按照列
        y = data.loc[:, 'labels']            # :->所有行   ‘labels’这一列
    
    
        pd.value_counts(y)    # 这是察看这一列所有的值和出现次数
    
    
        fig1 = plt.figure()
        label0 =  plt.scatter(X.loc[:,'V1'][y==0],X.loc[:,'V2'][y==0])
        label1 =  plt.scatter(X.loc[:,'V1'][y==1],X.loc[:,'V2'][y==1])
        label2 =  plt.scatter(X.loc[:,'V1'][y==2],X.loc[:,'V2'][y==2])
        plt.title('un-labled data')
        plt.xlabel('V1')
        plt.ylabel('V2')
        plt.legend((label0,label1,label2),('label0','label1,','label2'))
        plt.show()
    
        # 创建模型
    
        # 得到中心点  和原来的图像放在一起观看
    
        # 预测结果 KM.predict([[80,60]])
    
        # 发现结果对但是标识顺序乱了
    
        #计算准确率   你会发现低的可怜  这是因为你的标识和颜色没有对上,需要修改
    
       # 无监督的学习 是没有标识的 需要修改
    
    
    
        # 创建Knn模型
        KNN = KNeighborsClassifier(n_neighbors=3)
        KNN.fit(X,y)
        y_predict_knn = KNN.predict(X)
        print('knn_accuracy:',accuracy_score(y,y_predict_knn))
    
        '''
            这时候不需要额外的矫正
        '''
    
        # centers = KNN.cluster_centers_
        fig4 = plt.subplot(121)
        label0 = plt.scatter(X.loc[:, 'V1'][y_predict_knn == 0], X.loc[:, 'V2'][y_predict_knn == 0])
        label1 = plt.scatter(X.loc[:, 'V1'][y_predict_knn == 1], X.loc[:, 'V2'][y_predict_knn == 1])
        label2 = plt.scatter(X.loc[:, 'V1'][y_predict_knn == 2], X.loc[:, 'V2'][y_predict_knn == 2])
        plt.title('predict data')
        plt.xlabel('V1')
        plt.ylabel('V2')
        plt.legend((label0, label1, label2), ('label0', 'label1,', 'label2'))
        # plt.scatter(centers[:, 0], centers[:, 1])
    
        fig5 = plt.subplot(122)
        label0 = plt.scatter(X.loc[:, 'V1'][y == 0], X.loc[:, 'V2'][y == 0])
        label1 = plt.scatter(X.loc[:, 'V1'][y == 1], X.loc[:, 'V2'][y == 1])
        label2 = plt.scatter(X.loc[:, 'V1'][y == 2], X.loc[:, 'V2'][y == 2])
        plt.title('un-labled data')
        plt.xlabel('V1')
        plt.ylabel('V2')
        plt.legend((label0, label1, label2), ('label0', 'label1,', 'label2'))
        # plt.scatter(centers[:, 0], centers[:, 1])
        plt.show()
    
    
        # meanshift
        from sklearn.cluster import MeanShift,estimate_bandwidth
        #obtain the bandwidth
        bw = estimate_bandwidth(X,n_samples=500)  # 数据集,和想要通过多少个样本进行估算
        print(bw)
        # 创建模型
        ms = MeanShift(bandwidth=bw)
        ms.fit(X)
        y_predict_ms = ms.predict(X)
        print(pd.value_counts(y_predict_ms),pd.value_counts(y))
    
    
    
        fig6 = plt.subplot(121)
        label0 = plt.scatter(X.loc[:, 'V1'][y_predict_ms == 0], X.loc[:, 'V2'][y_predict_ms == 0])
        label1 = plt.scatter(X.loc[:, 'V1'][y_predict_ms == 1], X.loc[:, 'V2'][y_predict_ms == 1])
        label2 = plt.scatter(X.loc[:, 'V1'][y_predict_ms == 2], X.loc[:, 'V2'][y_predict_ms == 2])
        plt.title('predict data')
        plt.xlabel('V1')
        plt.ylabel('V2')
        plt.legend((label0, label1, label2), ('label0', 'label1,', 'label2'))
        # plt.scatter(centers[:, 0], centers[:, 1])
    
        fig7 = plt.subplot(122)
        label0 = plt.scatter(X.loc[:, 'V1'][y == 0], X.loc[:, 'V2'][y == 0])
        label1 = plt.scatter(X.loc[:, 'V1'][y == 1], X.loc[:, 'V2'][y == 1])
        label2 = plt.scatter(X.loc[:, 'V1'][y == 2], X.loc[:, 'V2'][y == 2])
        plt.title('un-labled data')
        plt.xlabel('V1')
        plt.ylabel('V2')
        plt.legend((label0, label1, label2), ('label0', 'label1,', 'label2'))
        # plt.scatter(centers[:, 0], centers[:, 1])
        plt.show()
        # 无监督的学习 是没有标识的 需要修改
        y_corrected_ms = []
        for i in y_predict_ms:
            if i == 0:
                y_corrected_ms.append(2)
            elif i == 1:
                y_corrected_ms.append(1)
            else:
                y_corrected_ms.append(0)
  • 相关阅读:
    极致平台开发技巧介绍1如何利用升级包,快速给客户升级
    如何用极致业务基础平台做一个通用企业ERP系列之三启用期间管理设计
    caca需要用到x11作为图形输出
    spring boot 使用 mybatis 开启事务回滚 的总结
    RabbitMQ --- 直连交换机 【 同步操作,等到消费者处理完后返回处理结果 】
    RabbitMQ --- 直连交换机 【 有回调方法,获取消费结果 】
    RabbitMQ --- 直连交换机 【 无回调方法,不能获取消费结果 】
    spring boot 启动警告 WARN 15684 --- [ restartedMain] c.n.c.sources.URLConfigurationSource : No URLs will be polled as dynamic configuration sources. 解决
    Java基础复习到此结束,统一把源码放到GitHub仓库了,响应开源精神
    用一道题 来 复习 MySQL 的 复杂 sql 语句
  • 原文地址:https://www.cnblogs.com/chaogehahaha/p/15438650.html
Copyright © 2011-2022 走看看