zoukankan      html  css  js  c++  java
  • python 实现 KNN 分类器——手写识别

    1 算法概述

    1.1 优劣

    优点:进度高,对异常值不敏感,无数据输入假定

    缺点:计算复杂度高,空间复杂度高

    应用:主要用于文本分类,相似推荐

    适用数据范围:数值型和标称型

    1.2 算法伪代码

    (1)计算已知类别数据集中的点与当前点的距离

    (2)按照距离递增次序排序,选取与当前点距离最小的 k 个点

    (3)确定前 k 个点所在类别的出现频率

    (4)返回前 k 个点出现频率最高的类别作为当前点的预测分类

    2 手写识别

    2.1 概念

         指在手写设备上书写时产生的轨迹信息转化为具体字码,本篇博客重点非搭建手写识别系统,而是帮助理解 KNN。

    2.2 编程实现步骤

    (1)将图片(txt 文本)转为一个向量,即32*32的数组转化为1*1024的数组(特征向量)

    (2)将特征向量转化为矩阵

    (3)计算每个测试集中的特征向量和训练集中的特征向量的距离,选取距离较小的前 k 个,该特征向量对应的图片数字为 k 个图片中出现次数最多的那个数字。

    2.3 具体代码

    (1)转化为1*1024特征向量

    def img2vector(filename):
        returnVect = zeros((1,1024))
        fr = open(filename)
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVect[0,32*i+j] = int(lineStr[j])
        return returnVect

    (2)计算欧式距离,返回测试图片的类别

    def classify0(inX, dataSet, labels, k):
        dataSetSize = dataSet.shape[0]                  
        diffMat = tile(inX, (dataSetSize,1)) - dataSet   # shape[0]得出dataSet的行数,即样本个数   
        sqDiffMat = diffMat**2                           # tile(A,(m,n))将数组A作为元素构造m行n列的数组
        sqDistances = sqDiffMat.sum(axis=1)                  
        distances = sqDistances**0.5
        sortedDistIndicies = distances.argsort()         # array.argsort(),得到每个元素的排序序号   
        classCount={}                                    # sortedDistIndicies[0]表示排序后排在第一个的那个数在原来数组中的下标  
        for i in range(k):
            voteIlabel = labels[sortedDistIndicies[i]]
            classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1 # 从字典中获取key对应的value,没有key的话返回0
        # sorted()函数,按照第二个元素即value的次序逆向(reverse=True)排序  
        sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
        return sortedClassCount[0][0]

    (3)将每个向量合成矩阵,并对测试集中的每个样本分类

    def handwritingClassTest():
        hwLabels = []
        # os模块中的listdir('str')可以读取目录str下的所有文件名,返回一个字符串列表  
        trainingFileList = listdir('trainingDigits')          
        m = len(trainingFileList)
        trainingMat = zeros((m,1024))
        for i in range(m):
            fileNameStr = trainingFileList[i]                  
            fileStr = fileNameStr.split('.')[0]                
            classNumStr = int(fileStr.split('_')[0])          
            hwLabels.append(classNumStr)
            trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)
        
        # 逐一读取测试图片,同时将其分类 
        testFileList = listdir('testDigits')       
        errorCount = 0.0
        mTest = len(testFileList)
        for i in range(mTest):
            fileNameStr = testFileList[i]
            fileStr = fileNameStr.split('.')[0]     
            classNumStr = int(fileStr.split('_')[0])
            vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
            classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
            print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))
            if (classifierResult != classNumStr): 
                errorCount += 1.0
        print("
    the total number of errors is: %d" % errorCount)
        print("
    the total error rate is: %f" % (errorCount/float(mTest)))

    3 运行结果

        进入模块所在的文件夹,打开  Spyder,运行模块。然后在  Ipython 控制台输入以下代码:

    import KNN
    KNN.handwritingClassTest()

        得到以下结果:


        在 k = 3 的时候,错误率为1.2%。


    参考资料:

    《机器学习实战》



  • 相关阅读:
    vue+node.js+webpack开发微信公众号功能填坑——组件按需引入
    myeclipse打开jsp页面慢或者卡死
    myeclipse自动添加注释
    解决java.lang.NoSuchMethodError:org.joda.time.DateTime.withTimeAtStartOfDay() Lorg/joda/time/DateTime
    Echarts柱状图实现不同颜色渐变色
    《Python学习手册 第五版》 -第38章 被管理的属性
    《Python学习手册 第五版》 -第37章 Unicode和字节串
    《Python学习手册 第五版》 -第36章 异常的设计
    《Python学习手册 第五版》 -第35章 异常对象
    《Python学习手册 第五版》 -第34章 异常编写细节
  • 原文地址:https://www.cnblogs.com/mtcnn/p/9411607.html
Copyright © 2011-2022 走看看