zoukankan      html  css  js  c++  java
  • 机器学习python实战----手写数字识别

      看了上一篇内容之后,相信对K近邻算法有了一个清晰的认识,今天的内容——手写数字识别是对上一篇内容的延续,这里也是为了自己能更熟练的掌握k-NN算法。

    我们有大约2000个训练样本和1000个左右测试样本,训练样本所在的文件夹是trainingDigits,测试样本所在的文件夹是testDigits。文本文件中是0~9的数字,但是是用二值图表示出来的,如图。我们要做的就是使用训练样本训练模型,并用测试样本来检测模型的性能。

    首先,我们需要将文本文件中的内容转化为向量,因为图片大小是32*32,所以我们可以将其转化为1*1024的向量。具体代码实现如下:

    def img2vector(filename):
        imgVec = zeros((1,1024))
        file = open(filename)
        for i in range(32):
            lines = file.readline()
            for j in range(32):
                imgVec[0,32*i+j] = lines[j]
        return imgVec
        

    实现了图片到向量的转化之后,我们就可以对测试文件中的内容进行识别了。这里的识别我们可以使用上一篇中的自定义函数classify0,这个函数的第一个参数是测试向量,第二个参数是训练数据集,第三个参数是训练集的标签。所以,我们首先需要将训练数据集转化为(1934*1024)的矩阵,1934这里是训练集的组数即trainingDigits目录下的文件数 ,其对应的标签转化为(1*1934)的向量。之后要编写的代码就是对测试数据集中的每个文本文件进行识别,也就是需要将每个文件都转化成一个(1*1024)的向量,再传入classify0函数的第一个形参。整体代码如下:

    def handWriteNumClassTest():
        NumLabels = []
        TrainingDirfile = listdir(r'D:ipython
    um_recognize	rainingDigits')#文件目录
        L = len(TrainingDirfile)   #该目录中有多少文件
        TrainMat = zeros((L,1024))
        for i in range(L):
            file_n = TrainingDirfile[i]
            fileName = file_n.split('.')[0]
            ClassName = int(file_n.split('_')[0])
            NumLabels.append(ClassName)
            TrainMat[i,:] = img2vector(r'D:ipython
    um_recognize	rainingDigits\%s'%file_n)
        TestfileDir = listdir(r'D:ipython
    um_recognize	estDigits')
        error_cnt = 0.0
        M = len(TestfileDir)
        for j in range(M):
            Testfile = TestfileDir[j]
            TestfileName = Testfile.split('.')[0]
            TestClassName = int(Testfile.split('_')[0])
            TestVector = img2vector(r'D:ipython
    um_recognize	estDigits\%s'%Testfile)
            result = classify0(TestVector,TrainMat,NumLabels,3)
            print('the result is %d,the real answer is %d
    '%(result,TestClassName))
            if result!=TestClassName:
                error_cnt+=1
        print('the total num of errors is %f
    '%error_cnt)
        print('the error rate is %f
    '%(error_cnt/float(M)))

    这里需要首先导入listdir方法,from os import listdir,它可以列出给定目录的文件名。对于测试的每个文件,如果识别的分类结果跟真实结果不一样,则错误数+1,最终用错误数/测试总数 来表示该模型的性能。下面给出结果

    这里测试的总共946个项目中,一共有10个出现了错误,出错率为1%,这个性能还是可以接受的。有了上一篇内容的理解,这篇就简单多了吧!

    训练数据集和测试集文件下载:http://pan.baidu.com/s/1hsMntJ2

  • 相关阅读:
    原型模式
    浅复制和深复制
    适配器模式
    外观模式
    模板方法
    建造者模式
    代理模式
    Centos7重新安装yum
    关于mongodb创建索引的一些经验总结(转)
    MongoDB查询语句(转)
  • 原文地址:https://www.cnblogs.com/kl2blog/p/7751006.html
Copyright © 2011-2022 走看看