zoukankan      html  css  js  c++  java
  • KNN算法——python实现

    二、Python实现

           对于机器学习而已,Python需要额外安装三件宝,分别是Numpy,scipy和Matplotlib。前两者用于数值计算,后者用于画图。安装很简单,直接到各自的官网下载回来安装即可。安装程序会自动搜索我们的python版本和目录,然后安装到python支持的搜索路径下。反正就python和这三个插件都默认安装就没问题了。

           另外,如果我们需要添加我们的脚本目录进Python的目录(这样Python的命令行就可以直接import),可以在系统环境变量中添加:PYTHONPATH环境变量,值为我们的路径,例如:E:PythonMachine Learning in Action

     

    2.1、kNN基础实践

           一般实现一个算法后,我们需要先用一个很小的数据库来测试它的正确性,否则一下子给个大数据给它,它也很难消化,而且还不利于我们分析代码的有效性。

          首先,我们新建一个kNN.py脚本文件,文件里面包含两个函数,一个用来生成小数据库,一个实现kNN分类算法。代码如下:

     

    [python] view plain copy
     
     在CODE上查看代码片派生到我的代码片
    1. #########################################  
    2. # kNN: k Nearest Neighbors  
    3.   
    4. # Input:      newInput: vector to compare to existing dataset (1xN)  
    5. #             dataSet:  size m data set of known vectors (NxM)  
    6. #             labels:   data set labels (1xM vector)  
    7. #             k:        number of neighbors to use for comparison   
    8.               
    9. # Output:     the most popular class label  
    10. #########################################  
    11.   
    12. from numpy import *  
    13. import operator  
    14.   
    15. # create a dataset which contains 4 samples with 2 classes  
    16. def createDataSet():  
    17.     # create a matrix: each row as a sample  
    18.     group = array([[1.0, 0.9], [1.0, 1.0], [0.1, 0.2], [0.0, 0.1]])  
    19.     labels = ['A', 'A', 'B', 'B'] # four samples and two classes  
    20.     return group, labels  
    21.   
    22. # classify using kNN  
    23. def kNNClassify(newInput, dataSet, labels, k):  
    24.     numSamples = dataSet.shape[0] # shape[0] stands for the num of row  
    25.   
    26.     ## step 1: calculate Euclidean distance  
    27.     # tile(A, reps): Construct an array by repeating A reps times  
    28.     # the following copy numSamples rows for dataSet  
    29.     diff = tile(newInput, (numSamples, 1)) - dataSet # Subtract element-wise  
    30.     squaredDiff = diff ** # squared for the subtract  
    31.     squaredDist = sum(squaredDiff, axis = 1) # sum is performed by row  
    32.     distance = squaredDist ** 0.5  
    33.   
    34.     ## step 2: sort the distance  
    35.     # argsort() returns the indices that would sort an array in a ascending order  
    36.     sortedDistIndices = argsort(distance)  
    37.   
    38.     classCount = {} # define a dictionary (can be append element)  
    39.     for i in xrange(k):  
    40.         ## step 3: choose the min k distance  
    41.         voteLabel = labels[sortedDistIndices[i]]  
    42.   
    43.         ## step 4: count the times labels occur  
    44.         # when the key voteLabel is not in dictionary classCount, get()  
    45.         # will return 0  
    46.         classCount[voteLabel] = classCount.get(voteLabel, 0) + 1  
    47.   
    48.     ## step 5: the max voted class will return  
    49.     maxCount = 0  
    50.     for key, value in classCount.items():  
    51.         if value > maxCount:  
    52.             maxCount = value  
    53.             maxIndex = key  
    54.   
    55.     return maxIndex   

     

           然后我们在命令行中这样测试即可:

     

    [python] view plain copy
     
     在CODE上查看代码片派生到我的代码片
    1. import kNN  
    2. from numpy import *   
    3.   
    4. dataSet, labels = kNN.createDataSet()  
    5.   
    6. testX = array([1.2, 1.0])  
    7. k = 3  
    8. outputLabel = kNN.kNNClassify(testX, dataSet, labels, 3)  
    9. print "Your input is:", testX, "and classified to class: ", outputLabel  
    10.   
    11. testX = array([0.1, 0.3])  
    12. outputLabel = kNN.kNNClassify(testX, dataSet, labels, 3)  
    13. print "Your input is:", testX, "and classified to class: ", outputLabel  

     

           这时候会输出:

     

    [python] view plain copy
     
     在CODE上查看代码片派生到我的代码片
    1. Your input is: [ 1.2  1.0] and classified to class:  A  
    2. Your input is: [ 0.1  0.3] and classified to class:  B  

     

    2.2、kNN进阶

           这里我们用kNN来分类一个大点的数据库,包括数据维度比较大和样本数比较多的数据库。这里我们用到一个手写数字的数据库,可以到这里下载。这个数据库包括数字0-9的手写体。每个数字大约有200个样本。每个样本保持在一个txt文件中。手写体图像本身的大小是32x32的二值图,转换到txt文件保存后,内容也是32x32个数字,0或者1,如下:

     

           数据库解压后有两个目录:目录trainingDigits存放的是大约2000个训练数据,testDigits存放大约900个测试数据。

            这里我们还是新建一个kNN.py脚本文件,文件里面包含四个函数,一个用来生成将每个样本的txt文件转换为对应的一个向量,一个用来加载整个数据库,一个实现kNN分类算法。最后就是实现这个加载,测试的函数。

     

    [python] view plain copy
     
     在CODE上查看代码片派生到我的代码片
    1. #########################################  
    2. # kNN: k Nearest Neighbors  
    3.   
    4. # Input:      inX: vector to compare to existing dataset (1xN)  
    5. #             dataSet: size m data set of known vectors (NxM)  
    6. #             labels: data set labels (1xM vector)  
    7. #             k: number of neighbors to use for comparison   
    8.               
    9. # Output:     the most popular class label  
    10. #########################################  
    11.   
    12. from numpy import *  
    13. import operator  
    14. import os  
    15.   
    16.   
    17. # classify using kNN  
    18. def kNNClassify(newInput, dataSet, labels, k):  
    19.     numSamples = dataSet.shape[0] # shape[0] stands for the num of row  
    20.   
    21.     ## step 1: calculate Euclidean distance  
    22.     # tile(A, reps): Construct an array by repeating A reps times  
    23.     # the following copy numSamples rows for dataSet  
    24.     diff = tile(newInput, (numSamples, 1)) - dataSet # Subtract element-wise  
    25.     squaredDiff = diff ** # squared for the subtract  
    26.     squaredDist = sum(squaredDiff, axis = 1) # sum is performed by row  
    27.     distance = squaredDist ** 0.5  
    28.   
    29.     ## step 2: sort the distance  
    30.     # argsort() returns the indices that would sort an array in a ascending order  
    31.     sortedDistIndices = argsort(distance)  
    32.   
    33.     classCount = {} # define a dictionary (can be append element)  
    34.     for i in xrange(k):  
    35.         ## step 3: choose the min k distance  
    36.         voteLabel = labels[sortedDistIndices[i]]  
    37.   
    38.         ## step 4: count the times labels occur  
    39.         # when the key voteLabel is not in dictionary classCount, get()  
    40.         # will return 0  
    41.         classCount[voteLabel] = classCount.get(voteLabel, 0) + 1  
    42.   
    43.     ## step 5: the max voted class will return  
    44.     maxCount = 0  
    45.     for key, value in classCount.items():  
    46.         if value > maxCount:  
    47.             maxCount = value  
    48.             maxIndex = key  
    49.   
    50.     return maxIndex   
    51.   
    52. # convert image to vector  
    53. def  img2vector(filename):  
    54.     rows = 32  
    55.     cols = 32  
    56.     imgVector = zeros((1, rows * cols))   
    57.     fileIn = open(filename)  
    58.     for row in xrange(rows):  
    59.         lineStr = fileIn.readline()  
    60.         for col in xrange(cols):  
    61.             imgVector[0, row * 32 + col] = int(lineStr[col])  
    62.   
    63.     return imgVector  
    64.   
    65. # load dataSet  
    66. def loadDataSet():  
    67.     ## step 1: Getting training set  
    68.     print "---Getting training set..."  
    69.     dataSetDir = 'E:/Python/Machine Learning in Action/'  
    70.     trainingFileList = os.listdir(dataSetDir + 'trainingDigits') # load the training set  
    71.     numSamples = len(trainingFileList)  
    72.   
    73.     train_x = zeros((numSamples, 1024))  
    74.     train_y = []  
    75.     for i in xrange(numSamples):  
    76.         filename = trainingFileList[i]  
    77.   
    78.         # get train_x  
    79.         train_x[i, :] = img2vector(dataSetDir + 'trainingDigits/%s' % filename)   
    80.   
    81.         # get label from file name such as "1_18.txt"  
    82.         label = int(filename.split('_')[0]) # return 1  
    83.         train_y.append(label)  
    84.   
    85.     ## step 2: Getting testing set  
    86.     print "---Getting testing set..."  
    87.     testingFileList = os.listdir(dataSetDir + 'testDigits') # load the testing set  
    88.     numSamples = len(testingFileList)  
    89.     test_x = zeros((numSamples, 1024))  
    90.     test_y = []  
    91.     for i in xrange(numSamples):  
    92.         filename = testingFileList[i]  
    93.   
    94.         # get train_x  
    95.         test_x[i, :] = img2vector(dataSetDir + 'testDigits/%s' % filename)   
    96.   
    97.         # get label from file name such as "1_18.txt"  
    98.         label = int(filename.split('_')[0]) # return 1  
    99.         test_y.append(label)  
    100.   
    101.     return train_x, train_y, test_x, test_y  
    102.   
    103. # test hand writing class  
    104. def testHandWritingClass():  
    105.     ## step 1: load data  
    106.     print "step 1: load data..."  
    107.     train_x, train_y, test_x, test_y = loadDataSet()  
    108.   
    109.     ## step 2: training...  
    110.     print "step 2: training..."  
    111.     pass  
    112.   
    113.     ## step 3: testing  
    114.     print "step 3: testing..."  
    115.     numTestSamples = test_x.shape[0]  
    116.     matchCount = 0  
    117.     for i in xrange(numTestSamples):  
    118.         predict = kNNClassify(test_x[i], train_x, train_y, 3)  
    119.         if predict == test_y[i]:  
    120.             matchCount += 1  
    121.     accuracy = float(matchCount) / numTestSamples  
    122.   
    123.     ## step 4: show the result  
    124.     print "step 4: show the result..."  
    125.     print 'The classify accuracy is: %.2f%%' % (accuracy * 100)  

     

           测试非常简单,只需要在命令行中输入:

     

    [python] view plain copy
     
     在CODE上查看代码片派生到我的代码片
    1. import kNN  
    2. kNN.testHandWritingClass()  

     

           输出结果如下:

     

    [python] view plain copy
     
     在CODE上查看代码片派生到我的代码片
      1. step 1: load data...  
      2. ---Getting training set...  
      3. ---Getting testing set...  
      4. step 2: training...  
      5. step 3: testing...  
      6. step 4: show the result...  
      7. The classify accuracy is: 98.84%  

    个人修改一些注释:

    # -*- coding: utf-8 -*-  
    """
    KNN: K Nearest Neighbors Input: newInput:vector to compare to existing dataset(1xN) dataSet:size m data set of known vectors(NxM) labels:data set labels(1xM vector) k:number of neighbors to use for comparison Output: the most popular class labels N为数据的维度 M为数据个数 """ from numpy import * import operator #create a dataset which contains 4 samples with 2 classes def createDataSet(): #create a matrix:each row as a sample group = array([[1.0,0.9],[1.0,1.0],[0.1,0.2],[0.0,0.1]]) #four samples and two classes labels = ['A','A','B','B'] return group,labels #classify using KNN def KNNClassify(newInput, dataSet, labels, k): numSamples = dataSet.shape[0] #shape[0] stands for the num of row 即是m ##step 1:calculate Euclidean distance #tile(A,reps):Construct an array by repeating A reps times #the following copy numSamples rows for dataSet diff = tile(newInput,(numSamples,1)) - dataSet #Subtract element-wise squaredDiff = diff ** 2 #squared for the subtract squaredDist = sum(squaredDiff, axis = 1) #sum is performed by row distance = squaredDist ** 0.5 ##step 2:sort the distance #argsort() return the indices that would sort an array in a ascending order sortedDistIndices = argsort(distance) classCount = {} #define a dictionary (can be append element) for i in xrange(k): ##step 3:choose the min k diatance voteLabel = labels[sortedDistIndices[i]] ##step 4:count the times labels occur #when the key voteLabel is not in dictionary classCount,get() #will return 0 #按classCount字典的第2个元素(即类别出现的次数)从大到小排序 #即classCount是一个字典,key是类型,value是该类型出现的次数,通过for循环遍历来计算 classCount[voteLabel] = classCount.get(voteLabel,0) + 1 ##step 5:the max voted class will return #eg:假设classCount={'A':3,'B':2} maxCount = 0 for key,value in classCount.items(): if value > maxCount: maxCount = value maxIndex = key return maxIndex
  • 相关阅读:
    PL/SQL跨库查询数据
    oracle 两个时间相减
    导出Excel格式数据
    Java导出pdf文件数据
    $.ajax相关用法
    oracle 删除掉重复数据只保留一条
    常用Oracle操作语句
    JS请求服务器,并返回信息,请求过程中不需要跳转页面
    tomcat部署web项目的3中方法
    Date()日期转换和简单计算
  • 原文地址:https://www.cnblogs.com/GDUT-xiang/p/5701446.html
Copyright © 2011-2022 走看看