
                                                                                            KNN Notes

    First, let's load one of scikit-learn's bundled datasets, and then get to KNN.

    import numpy as np
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    from sklearn import datasets

    iris = datasets.load_iris()

    Take a look at the iris dataset's keys:

    iris.keys()

    The result is:

    dict_keys(['data', 'target', 'target_names', 'DESCR', 'feature_names'])

    Take a look at the documentation:

    print(iris.DESCR)  # view the dataset documentation

    The documentation reads:

    Iris Plants Database
    ====================
    
    Notes
    -----
    Data Set Characteristics:
        :Number of Instances: 150 (50 in each of three classes)
        :Number of Attributes: 4 numeric, predictive attributes and the class
        :Attribute Information:
            - sepal length in cm
            - sepal width in cm
            - petal length in cm
            - petal width in cm
            - class:
                    - Iris-Setosa
                    - Iris-Versicolour
                    - Iris-Virginica
        :Summary Statistics:
    
        ============== ==== ==== ======= ===== ====================
                        Min  Max   Mean    SD   Class Correlation
        ============== ==== ==== ======= ===== ====================
        sepal length:   4.3  7.9   5.84   0.83    0.7826
        sepal width:    2.0  4.4   3.05   0.43   -0.4194
        petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)
        petal width:    0.1  2.5   1.20   0.76    0.9565  (high!)
        ============== ==== ==== ======= ===== ====================
    
        :Missing Attribute Values: None
        :Class Distribution: 33.3% for each of 3 classes.
        :Creator: R.A. Fisher
        :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)
        :Date: July, 1988
    
    This is a copy of UCI ML iris datasets.
    http://archive.ics.uci.edu/ml/datasets/Iris
    
    The famous Iris database, first used by Sir R.A Fisher
    
    This is perhaps the best known database to be found in the
    pattern recognition literature.  Fisher's paper is a classic in the field and
    is referenced frequently to this day.  (See Duda & Hart, for example.)  The
    data set contains 3 classes of 50 instances each, where each class refers to a
    type of iris plant.  One class is linearly separable from the other 2; the
    latter are NOT linearly separable from each other.
    
    References
    ----------
       - Fisher,R.A. "The use of multiple measurements in taxonomic problems"
         Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to
         Mathematical Statistics" (John Wiley, NY, 1950).
       - Duda,R.O., & Hart,P.E. (1973) Pattern Classification and Scene Analysis.
         (Q327.D83) John Wiley & Sons.  ISBN 0-471-22361-1.  See page 218.
       - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System
         Structure and Classification Rule for Recognition in Partially Exposed
         Environments".  IEEE Transactions on Pattern Analysis and Machine
         Intelligence, Vol. PAMI-2, No. 1, 67-71.
       - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule".  IEEE Transactions
         on Information Theory, May 1972, 431-433.
       - See also: 1988 MLC Proceedings, 54-64.  Cheeseman et al's AUTOCLASS II
         conceptual clustering system finds 3 classes in the data.
       - Many, many more ...

    Take a look at the data:

    iris.data  # view the feature matrix

    The data is:

    array([[ 5.1,  3.5,  1.4,  0.2],
           [ 4.9,  3. ,  1.4,  0.2],
           [ 4.7,  3.2,  1.3,  0.2],
           [ 4.6,  3.1,  1.5,  0.2],
           [ 5. ,  3.6,  1.4,  0.2],
           ...
           [ 6.7,  3. ,  5.2,  2.3],
           [ 6.3,  2.5,  5. ,  1.9],
           [ 6.5,  3. ,  5.2,  2. ],
           [ 6.2,  3.4,  5.4,  2.3],
           [ 5.9,  3. ,  5.1,  1.8]])

    (output truncated here; there are 150 rows in total)

    As you can see, data has 150 rows, each with 4 columns.
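To confirm the shape programmatically (a quick sketch; nothing beyond `load_iris` is assumed):

```python
from sklearn import datasets

iris = datasets.load_iris()
print(iris.data.shape)    # (150, 4): 150 samples, 4 features each
print(iris.target.shape)  # (150,): one label per sample
```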

    Take a look at target:

    iris.target  # view the corresponding labels

    The target is:

    array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
           0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
           0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
           1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
           1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
           2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
           2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

    Take a look at target_names:

    iris.target_names  # view the class names behind the labels

    The target_names are:

    array(['setosa', 'versicolor', 'virginica'],
          dtype='<U10')

    That is, the labels 0, 1, and 2 in target correspond to these three iris species.
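As a quick sketch, NumPy fancy indexing maps each numeric label to its species name, and `np.bincount` shows the classes are perfectly balanced:

```python
import numpy as np
from sklearn import datasets

iris = datasets.load_iris()

# Fancy indexing: look up each label's name in target_names
names = iris.target_names[iris.target]
print(names[:3])                 # the first samples are 'setosa'

# 50 samples per class, matching the "Class Distribution" in the docs
print(np.bincount(iris.target))  # [50 50 50]
```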

    Take a look at what the 4 columns of data refer to:

    iris.feature_names  # view what the four columns mean

    The result is:

    ['sepal length (cm)',
     'sepal width (cm)',
     'petal length (cm)',
     'petal width (cm)']

    That is, the 4 columns are the sepal length, sepal width, petal length, and petal width.

    Take a look at the sepal data, i.e. the first two columns:

    # Scatter plot of the sepals
    X = iris.data[:, :2]
    plt.scatter(X[:, 0], X[:, 1])
    plt.xlabel("sepal length")
    plt.ylabel("sepal width")
    plt.title("DU's plot about sepal")
    plt.show()

      

    Now separate the three species in the scatter plot:

    # Scatter plot of the sepals, one color per species
    y = iris.target
    plt.scatter(X[y==0, 0], X[y==0, 1], color='b')
    plt.scatter(X[y==1, 0], X[y==1, 1], color='r')
    plt.scatter(X[y==2, 0], X[y==2, 1], color='g')
    plt.xlabel("sepal length")
    plt.ylabel("sepal width")
    plt.title("DU's plot about sepal")
    plt.show()

    Next, the petal scatter plot:

    Petal = iris.data[:, 2:]
    y = iris.target
    plt.scatter(Petal[y==0, 0], Petal[y==0, 1], color='b')
    plt.scatter(Petal[y==1, 0], Petal[y==1, 1], color='r')
    plt.scatter(Petal[y==2, 0], Petal[y==2, 1], color='g')
    plt.xlabel("Petal length")
    plt.ylabel("Petal width")
    plt.title("DU's plot about Petal")
    plt.show()

    With the petal scatter plot in hand, we can talk about KNN. Suppose a new point arrives in the petal plot with length 2 cm and width 0.5 cm: which iris does it represent? Most people would intuitively conclude that it belongs with the blue points, because the new point is closest to the blue region and far from both the red and green regions. That is exactly the idea behind KNN (k-nearest neighbors): classify a point by the classes of the points nearest to it.
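This intuition can be verified with scikit-learn's built-in `KNeighborsClassifier` (a sketch, trained on the two petal columns only; k = 5 is an assumption made here):

```python
from sklearn import datasets
from sklearn.neighbors import KNeighborsClassifier

iris = datasets.load_iris()
Petal = iris.data[:, 2:]   # petal length and petal width only
y = iris.target

knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(Petal, y)

# The hypothetical new point: petal length 2 cm, petal width 0.5 cm
pred = knn.predict([[2.0, 0.5]])
print(iris.target_names[pred])   # ['setosa'] -- the blue cluster
```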

    For example, suppose we have the following mock data:

    raw_X = [[1, 2],
             [2.8, 2.5],
             [4, 3.2],
             [2, 1.5],
             [6, 7.8],
             [8, 5],
             [9, 7],
             [7, 8.5],
             [10, 9.7],
            ]
    raw_y = [0, 0, 0, 0, 1, 1, 1, 1, 1]
    X_train = np.array(raw_X)
    y_train = np.array(raw_y)

    Now a new data point x (drawn as a green point) arrives, and we need to decide which class it belongs to:

    x = np.array([7.5, 6.5])
    plt.scatter(X_train[y_train==0, 0], X_train[y_train==0, 1])
    plt.scatter(X_train[y_train==1, 0], X_train[y_train==1, 1], color='r')
    plt.scatter(x[0], x[1], color='g')
    plt.show()

    Following the KNN approach, we now need to compute the distance from every training point to this green point and see which points are nearest to it.

    Using the Euclidean distance, d(a, b) = sqrt(Σᵢ (aᵢ - bᵢ)²), a typical programmer would write code like this:

    from math import sqrt

    distances = []
    for x_train in X_train:
        d = sqrt(np.sum((x_train - x) ** 2))  # square each difference BEFORE summing
        distances.append(d)

    Of course, a more seasoned programmer would collapse this into a one-line list comprehension:

    distances = [sqrt(np.sum((x_train - x) ** 2)) for x_train in X_train]

    Either way, distances comes out (rounded to three decimals) as:

    [7.906, 6.172, 4.810, 7.433, 1.985, 1.581, 1.581, 2.062, 4.061]
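The loop can also be replaced by a single vectorized expression using NumPy broadcasting (a sketch; `np.linalg.norm` with `axis=1` computes the same Euclidean distance for every row at once):

```python
import numpy as np

X_train = np.array([[1, 2], [2.8, 2.5], [4, 3.2], [2, 1.5],
                    [6, 7.8], [8, 5], [9, 7], [7, 8.5], [10, 9.7]])
x = np.array([7.5, 6.5])

# Broadcasting subtracts x from every row; norm reduces each row to a scalar
distances = np.linalg.norm(X_train - x, axis=1)
print(np.round(distances, 3))
```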
    

    Next, sort the distances to get the indices of the nearest points, take the labels of the k nearest neighbors (here k = 5), and let them vote:

    from collections import Counter

    nearest = np.argsort(distances)                    # indices, sorted by distance
    topK_y = [y_train[neighbor] for neighbor in nearest[:5]]
    votes = Counter(topK_y)
    predict_y = votes.most_common(1)[0][0]
    predict_y

    The result, as expected, is 1.
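The whole procedure can be wrapped into a small reusable function (a sketch; the name `knn_classify`, its signature, and the default k = 5 are choices made here, not from any library):

```python
from collections import Counter
from math import sqrt
import numpy as np

def knn_classify(X_train, y_train, x, k=5):
    """Predict the label of x by majority vote among its k nearest neighbors."""
    distances = [sqrt(np.sum((x_train - x) ** 2)) for x_train in X_train]
    nearest = np.argsort(distances)
    topK_y = [y_train[i] for i in nearest[:k]]
    return Counter(topK_y).most_common(1)[0][0]

X_train = np.array([[1, 2], [2.8, 2.5], [4, 3.2], [2, 1.5],
                    [6, 7.8], [8, 5], [9, 7], [7, 8.5], [10, 9.7]])
y_train = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1])

print(knn_classify(X_train, y_train, np.array([7.5, 6.5])))  # 1
print(knn_classify(X_train, y_train, np.array([2.5, 2.0])))  # 0
```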

      

      
