zoukankan      html  css  js  c++  java
  • kNN算法python实现和简单数字识别

    kNN算法

    算法优缺点:

    • 优点:精度高、对异常值不敏感、无输入数据假定
    • 缺点:时间复杂度和空间复杂度都很高
    • 适用数据范围:数值型和标称型

    算法的思路:

    KNN算法(全称K最近邻算法),算法的思想很简单,简单的说就是物以类聚,也就是说我们从一堆已知的训练集中找出k个与目标最靠近的,然后看他们中最多的分类是哪个,就以这个为依据分类。 

    函数解析:

    库函数

    • tile()

      tile(A,n)就是将A重复n次

    a = np.array([0, 1, 2])
    np.tile(a, 2)
    array([0, 1, 2, 0, 1, 2])
    np.tile(a, (2, 2))
    array([[0, 1, 2, 0, 1, 2],[0, 1, 2, 0, 1, 2]])
    np.tile(a, (2, 1, 2))
    array([[[0, 1, 2, 0, 1, 2]],[[0, 1, 2, 0, 1, 2]]])
    b = np.array([[1, 2], [3, 4]])
    np.tile(b, 2)
    array([[1, 2, 1, 2],[3, 4, 3, 4]])
    np.tile(b, (2, 1))
    array([[1, 2],[3, 4],[1, 2],[3, 4]])`

    自己实现的函数

    createDataSet()生成测试数组
    kNNclassify(inputX, dataSet, labels, k)分类函数

    • inputX 输入的参数
    • dataSet 训练集
    • labels 训练集的标号
    • k 最近邻的数目
      1.  1 #coding=utf-8
         2 from numpy import *
         3 import operator
         4 
         5 def createDataSet():
         6     group = array([[1.0, 0.9], [1.0, 1.0], [0.1, 0.2], [0.0, 0.1]])
         7     labels = ['A','A','B','B']
         8     return group,labels
         9 #inputX表示输入向量(也就是我们要判断它属于哪一类的)
        10 #dataSet表示训练样本
        11 #label表示训练样本的标签
        12 #k是最近邻的参数,选最近k个
        13 def kNNclassify(inputX, dataSet, labels, k):
        14     dataSetSize = dataSet.shape[0]#计算有几个训练数据
        15     #开始计算欧几里得距离
        16     diffMat = tile(inputX, (dataSetSize,1)) - dataSet
        17     
        18     sqDiffMat = diffMat ** 2
        19     sqDistances = sqDiffMat.sum(axis=1)#矩阵每一行向量相加
        20     distances = sqDistances ** 0.5
        21     #欧几里得距离计算完毕
        22     sortedDistance = distances.argsort()
        23     classCount = {}
        24     for i in xrange(k):
        25         voteLabel = labels[sortedDistance[i]]
        26         classCount[voteLabel] = classCount.get(voteLabel,0) + 1
        27     res = max(classCount)
        28     return res
        29 
        30 def main():
        31     group,labels = createDataSet()
        32     t = kNNclassify([0,0],group,labels,3)
        33     print t
        34     
        35 if __name__=='__main__':
        36     main()
        37             

    kNN应用实例

    手写识别系统的实现

    数据集:

    两个数据集:training和test。分类的标号在文件名中。像素32*32的。数据大概这个样子:

    方法:

    kNN的使用,不过这个距离算起来比较复杂(1024个特征),主要是要处理如何读取数据这个问题的,比较方面直接调用就可以了。

    速度:

    速度还是比较慢的,这里数据集是:training 2000+,test 900+(i5的CPU)

    k=3的时候要32s+

    1.  1 #coding=utf-8
       2 from numpy import *
       3 import operator
       4 import os
       5 import time
       6 
       7 def createDataSet():
       8     group = array([[1.0, 0.9], [1.0, 1.0], [0.1, 0.2], [0.0, 0.1]])
       9     labels = ['A','A','B','B']
      10     return group,labels
      11 #inputX表示输入向量(也就是我们要判断它属于哪一类的)
      12 #dataSet表示训练样本
      13 #label表示训练样本的标签
      14 #k是最近邻的参数,选最近k个
      15 def kNNclassify(inputX, dataSet, labels, k):
      16     dataSetSize = dataSet.shape[0]#计算有几个训练数据
      17     #开始计算欧几里得距离
      18     diffMat = tile(inputX, (dataSetSize,1)) - dataSet
      19     #diffMat = inputX.repeat(dataSetSize, aixs=1) - dataSet
      20     sqDiffMat = diffMat ** 2
      21     sqDistances = sqDiffMat.sum(axis=1)#矩阵每一行向量相加
      22     distances = sqDistances ** 0.5
      23     #欧几里得距离计算完毕
      24     sortedDistance = distances.argsort()
      25     classCount = {}
      26     for i in xrange(k):
      27         voteLabel = labels[sortedDistance[i]]
      28         classCount[voteLabel] = classCount.get(voteLabel,0) + 1
      29     res = max(classCount)
      30     return res
      31 
      32 def img2vec(filename):
      33     returnVec = zeros((1,1024))
      34     fr = open(filename)
      35     for i in range(32):
      36         lineStr = fr.readline()
      37         for j in range(32):
      38             returnVec[0,32*i+j] = int(lineStr[j])
      39     return returnVec
      40     
      41 def handwritingClassTest(trainingFloder,testFloder,K):
      42     hwLabels = []
      43     trainingFileList = os.listdir(trainingFloder)
      44     m = len(trainingFileList)
      45     trainingMat = zeros((m,1024))
      46     for i in range(m):
      47         fileName = trainingFileList[i]
      48         fileStr = fileName.split('.')[0]
      49         classNumStr = int(fileStr.split('_')[0])
      50         hwLabels.append(classNumStr)
      51         trainingMat[i,:] = img2vec(trainingFloder+'/'+fileName)
      52     testFileList = os.listdir(testFloder)
      53     errorCount = 0.0
      54     mTest = len(testFileList)
      55     for i in range(mTest):
      56         fileName = testFileList[i]
      57         fileStr = fileName.split('.')[0]
      58         classNumStr = int(fileStr.split('_')[0])
      59         vectorUnderTest = img2vec(testFloder+'/'+fileName)
      60         classifierResult = kNNclassify(vectorUnderTest, trainingMat, hwLabels, K)
      61         #print classifierResult,' ',classNumStr
      62         if classifierResult != classNumStr:
      63             errorCount +=1
      64     print 'tatal error ',errorCount
      65     print 'error rate',errorCount/mTest
      66         
      67 def main():
      68     t1 = time.clock()
      69     handwritingClassTest('trainingDigits','testDigits',3)
      70     t2 = time.clock()
      71     print 'execute ',t2-t1
      72 if __name__=='__main__':
      73     main()
      74             





  • 相关阅读:
    idea的使用和安装破解 2019.2
    get请求和post请求的区别
    MySQL-事务
    MySQL-mysql的查询练习
    MySQL-mysql的多表查询
    CodeForces
    2018宁夏邀请赛网赛 I. Reversion Count(java练习题)
    HDU
    Codeforces Round #479 (Div. 3)解题报告
    nyoj 1274信道安全 第九届河南省赛(SPFA)
  • 原文地址:https://www.cnblogs.com/MrLJC/p/4098011.html
Copyright © 2011-2022 走看看