zoukankan      html  css  js  c++  java
  • 机器学习实战SKlearn之KNN手写数字识别

     1 # -*- coding: UTF-8 -*-
     2 import numpy as np
     3 import operator
     4 from os import listdir
     5 from sklearn.neighbors import KNeighborsClassifier as kNN
     6 
     7 """
     8 函数说明:将32x32的二进制图像转换为1x1024向量。
     9 
    10 Parameters:
    11     filename - 文件名
    12 Returns:
    13     returnVect - 返回二进制图像的1x1024向量
    14 """
    15 def img2vector(filename):
    16     #创建1x1024零向量,np.zeros((a,b)),a代表第一层括号(最外层)看向元素个数,b代表第二层括号内向量元素个数即内层
    17     #注意这里zeros(1,1024)产生的是二维数组[[0,0,....]]
    18     return_vect = np.zeros((1, 1024))  # 数组的索引都是从0开始,但是size/shape中都是实际数目。
    19     #打开文件
    20     fr = open(filename)
    21     #按行读取
    22     for i in range(32):
    23         #读一行数据
    24         lineStr = fr.readline()
    25         #每一行的前32个元素依次添加到returnVect中
    26         for j in range(32):    #range(32):0,1,2,....31
    27             return_vect[0,32*i+j] = int(lineStr[j]) #0:向量第一层内的第一个分向量
    28     #返回转换后的1x1024向量
    29     return return_vect
    30 
    31 """
    32 函数说明:手写数字分类测试
    33 """
    34 def handwriting_ClassTest():
    35     #测试集的Labels
    36     hwLabels = []
    37     #返回trainingDigits目录下的文件名
    38     trainingFileList = listdir('C:/Users/Administrator/Desktop/data/Ch02/digits/trainingDigits')
    39     #返回文件夹下文件的个数
    40     m = len(trainingFileList)
    41     #初始化训练的Mat矩阵,测试集
    42     trainingMat = np.zeros((m, 1024))
    43     #从文件名中解析出训练集的类别
    44     for i in range(m):
    45         #获得文件的名字
    46         fileName_Str = trainingFileList[i]
    47         #获得分类的数字
    48         classNumber = int(fileName_Str.split('_')[0])
    49         #将获得的类别添加到hwLabels中
    50         hwLabels.append(classNumber)
    51         #将每一个文件的1x1024数据存储到trainingMat矩阵中
    52         trainingMat[i,:] = img2vector('C:/Users/Administrator/Desktop/data/Ch02/digits/trainingDigits/%s' % (fileName_Str))
    53     #构建kNN分类器
    54     neigh = kNN(n_neighbors = 3, algorithm = 'auto')
    55     #拟合模型, trainingMat为测试矩阵,hwLabels为对应的标签
    56     neigh.fit(trainingMat, hwLabels)
    57     #返回testDigits目录下的文件列表
    58     testFileList = listdir('C:/Users/Administrator/Desktop/data/Ch02/digits/testDigits')
    59     #错误检测计数
    60     errorCount = 0.0
    61     #测试数据的数量
    62     mTest = len(testFileList)
    63     #从文件中解析出测试集的类别并进行分类测试
    64     for i in range(mTest):
    65         #获得文件的名字
    66         fileName_Str = testFileList[i]
    67         #获得分类的数字
    68         classNumber = int(fileName_Str.split('_')[0])
    69         #获得测试集的1x1024向量,用于训练
    70         vector_UnderTest = img2vector('C:/Users/Administrator/Desktop/data/Ch02/digits/testDigits/%s' % (fileName_Str))
    71         #获得预测结果
    72         # classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
    73         classifierResult = neigh.predict(vector_UnderTest)
    74         print("分类返回结果为%d\t真实结果为%d" % (classifierResult, classNumber))
    75         if(classifierResult != classNumber):
    76             errorCount += 1.0
    77     print("总共错了%d个数据\n错误率为%f%%" % (errorCount, errorCount/mTest * 100))
    78 
    79 
    80 """
    81 函数说明:main函数
    82 """
    83 if __name__ == '__main__':
    84     handwriting_ClassTest()
  • 相关阅读:
    emacs 集成astyle
    git reflog
    rpm 打包的时候 不进行strip
    gmock
    如何对正在运行的进程,进行heap profile
    linux性能压测工具
    默认宏定义
    gdb fabs错误输出
    基于Clang的缓存型C++编译器Zapcc
    grep 多行 正则匹配
  • 原文地址:https://www.cnblogs.com/Henry-ZHAO/p/12725329.html
Copyright © 2011-2022 走看看