zoukankan      html  css  js  c++  java
  • 机器学习--kNN算法识别手写字母

     本文主要是用kNN算法对字母图片进行特征提取,分类识别。内容如下:

    1. kNN算法及相关Python模块介绍
    2. 对字母图片进行特征提取
    3. kNN算法实现
    4.  kNN算法分析

    一、kNN算法介绍

        K近邻(kNN,k-NearestNeighbor)分类算法是机器学习算法中最简单的方法之一。所谓K近邻,就是k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来代表。我们将样本分为训练样本和测试样本。对一个测试样本 t  进行分类,kNN的做法是先计算样本 t  到所有训练样本的欧氏距离,然后从中找出k个距离最短的训练样本,用这k个训练样本中出现次数最多的类别表示样本 t 的类别。

        欧式距离的计算公式:

          假设每个样本有两个特征值,如 A :(a1,b1)B:(a2,b2) 则AB的欧式距离为        

            

        举个例子:根据下图前四位同学的成绩和等级,预测第五位小白同学的等级

                                            

    我们可以看出:语文和数学成绩是一个学生的特征,等级是一个学生的类别。

    前四位同学是训练样本,第五位同学是测试样本。我们现在用kNN算法来预测第五位同学的等级,k取3。

    按照上面欧式距离公式我们可以计算

    d(5-1)== 7          d(5-2)== 30      

    d(5-3)== 6          d(5-4)== 19.2

    因为 k 取 3,所以我们寻找3个距离最近的样本,即编号为3,1,4的同学,他们的等级分别是 B,B,A。 这三个样本的分类中,出现了2次B,一次A,B出现次数最多,所以5号同学的等级可能为B

    常用Python模块

      NumPy:NumPy是Python的一种开源的数值计算扩展。这种工具可用来存储和处理大型矩阵,比Python自身的嵌套列表结构要高效的多。

      PIL:Python Imaging Library,是Python平台事实上的图像处理标准库,功能非常强大,API也简单易用。但PIL包主要针对Python2,不兼容Python3,所以在Python3中使用Pillow,后者是大牛根据PIL移植过来的,两者用法相同。

      上面两个Python库都可以通过pip进行安装。

    pip3 install [name]

      还有就是Python 自带标准库:shutil模块提供了大量的文件的高级操作,特别针对文件拷贝和删除,主要功能为目录和文件操作以及压缩操作。operator模块是Python 的运算符库,os 模块是Python的系统的和操作系统相关的函数库。

    二、对图片进行特征提取

    1、采集手写字母的图片素材

      有许多提供机器学习数据集的网站,如知乎上的整理 https://www.zhihu.com/question/63383992/answer/222718972  我搜集到的手写字母图片资源如下 链接:https://pan.baidu.com/s/1pM329fl  密码:i725   其中by_class.zip 压缩包是已经分类好的图片样本,可以直接下载使用

    2、提取图片素材的特征

      最简单的做法是将图片转换为由0 和1 组成的txt 文件,如

                 

    转换代码如下:

     1 import os
     2 import shutil
     3 from PIL import Image
     4 
     5 
     6 # image_file_prefix  png图片所在的文件夹
     7 # file_name png      png图片的名字
     8 # txt_path_prefix    转换后txt 文件所在的文件夹
     9 def generate_txt_image(image_file_prefix, file_name, txt_path_prefix):
    10     """将图片处理成只有0 和 1 的txt 文件"""
    11     # 将png图片转换成二值图并截取四周多余空白部分
    12     image_path = os.path.join(image_file_prefix, file_name)
    13     # convert('L') 将图片转为灰度图 convert('1') 将图片转为二值图
    14     img = Image.open(image_path, 'r').convert('1').crop((32, 32, 96, 96))
    15     # 指定转换后的宽 高
    16     width, height = 32, 32
    17    img.thumbnail((width, height), Image.ANTIALIAS)
    18     # 将二值图片转换为0 1,存储到二位数组arr中
    19     arr = []
    20     for i in range(width):
    21         pixels = []
    22         for j in range(height):
    23             pixel = int(img.getpixel((j, i)))
    24             pixel = 0 if pixel == 0 else 1
    25             pixels.append(pixel)
    26         arr.append(pixels)
    27 
    28     # 创建txt文件(mac下使用os.mknod()创建文件需要root权限,这里改用复制的方式)
    29     text_image_file = os.path.join(txt_path_prefix, file_name.split('.')[0] + '.txt')
    30     empty_txt_path = "/Users/beiyan/Downloads/empty.txt"
    31     shutil.copyfile(empty_txt_path, text_image_file)
    32 
    33     # 写入文件
    34     with open(text_image_file, 'w') as text_file_object:
    35         for line in arr:
    36             for e in line:
    37                 text_file_object.write(str(e))
    38             text_file_object.write("
    ")

    将所有素材转换为 txt 后,分为两部分:训练样本 和 测试样本。

    三、kNN算法实现

    1、将txt文件转为一维数组的方法:

    1 def img2vector(filename, width, height):
    2     """将txt文件转为一维数组"""
    3     return_vector = np.zeros((1, width * height))
    4     fr = open(filename)
    5     for i in range(height):
    6         line = fr.readline()
    7         for j in range(width):
    8             return_vector[0, height * i + j] = int(line[j])
    9     return return_vector

    2、对测试样本进行kNN分类,返回测试样本的类别:

     1 import numpy as np
     2 import os
     3 import operator
     4 
     5 
     6 # test_set 单个测试样本
     7 # train_set 训练样本二维数组
     8 # labels 训练样本对应的分类
     9 # k k值
    10 def classify(test_set, train_set, labels, k):
    11     """对测试样本进行kNN分类,返回测试样本的类别"""
    12     # 获取训练样本条数
    13     train_size = train_set.shape[0]
    14 
    15     # 计算特征值的差值并求平方
    16     # tile(A,(m,n)),功能是将数组A行重复m次 列重复n次
    17     diff_mat = np.tile(test_set, (train_size, 1)) - train_set
    18     sq_diff_mat = diff_mat ** 2
    19 
    20     # 计算欧式距离 存储到数组 distances
    21     sq_distances = sq_diff_mat.sum(axis=1)
    22     distances = sq_distances ** 0.5
    23 
    24     # 按距离由小到大排序对索引进行排序
    25     sorted_index = distances.argsort()
    26 
    27     # 求距离最短k个样本中 出现最多的分类
    28     class_count = {}
    29     for i in range(k):
    30         near_label = labels[sorted_index[i]]
    31         class_count[near_label] = class_count.get(near_label, 0) + 1
    32     sorted_class_count = sorted(class_count.items(), key=operator.itemgetter(1), reverse=True)
    33     return sorted_class_count[0][0]

    3、统计分类错误率

     1 # train_data_path 训练样本文件夹
     2 # test_data_path 测试样本文件夹
     3 # k k个最近邻居
     4 def get_error_rate(train_data_path, test_data_path, k):
     5     """统计识别错误率"""
     6     width, height = 32, 32
     7     train_labels = []
     8 
     9     training_file_list = os.listdir(train_data_path)
    10     train_size = len(training_file_list)
    11 
    12     # 生成全为0的训练集数组
    13     train_set = np.zeros((train_size, width * height))
    14 
    15     # 读取训练样本
    16     for i in range(train_size):
    17         file = training_file_list[i]
    18         file_name = file.split('.')[0]
    19         label = str(file_name.split('_')[0])
    20         train_labels.append(label)
    21         train_set[i, :] = img2vector(os.path.join(train_data_path, training_file_list[i]), width, height)
    22 
    23     test_file_list = os.listdir(test_data_path)
    24     # 识别错误的个数
    25     error_count = 0.0
    26     # 测试样本的个数
    27     test_count = len(test_file_list)
    28 
    29     # 统计识别错误的个数
    30     for i in range(test_count):
    31         file = test_file_list[i]
    32         true_label = file.split('.')[0].split('_')[0]
    33 
    34         test_set = img2vector(os.path.join(test_data_path, test_file_list[i]), width, height)
    35         test_label = classify(test_set, train_set, train_labels, k)
    36         print(true_label, test_label)
    37         if test_label != true_label:
    38             error_count += 1.0
    39     percent = error_count / float(test_count)
    40     print("识别错误率是:{}".format(str(percent)))

    上述完整代码地址:https://gitee.com/beiyan/machine_learning/tree/master/knn

    4、测试结果

      训练样本:  0-9,a-z,A-Z 共62个字符,每个字符选取120个训练样本 , 一共有7440 个训练样本。每个字符选取20个测试样本,一共1200个测试样本。

      尝试改变条件,测得识别正确率如下:

                                  

    四、kNN算法分析

      由上部分结果可知:knn算法对于手写字母的识别率并不理想。

      原因可能有以下几个方面:

     

      1、图片特征提取过于简单,图片边缘较多空白,且图片中字母的中心位置未必全部对应

      2、因为英文有些字母大小写比较相似,容易识别错误

      3、样本规模较小,每个字符最多只有300个训练样本,真正的训练需要海量数据

    在后序的文章中尝试用其他学习算法提高分类识别率。各位道友有更好的意见也欢迎提出!

  • 相关阅读:
    20100320 ~ 20100420 小结与本月计划
    datamining的思考
    谈谈网络蜘蛛 爬开心网001的一些体会
    将 ASP.NET MVC3 Razor 项目部署到虚拟主机中
    Eclipse代码中中文字显示很小的解决办法
    U8800一键ROOT删除定制软件 安装新版Docment to go
    Android(安卓) U8800 长按 搜索键、返回键 锁屏或解锁的设置方法
    JDK5.0新特性系列3.枚举类型
    JDK5.0新特性系列1.自动装箱和拆箱
    网游运营基本概念及专业术语
  • 原文地址:https://www.cnblogs.com/beiyan/p/8269102.html
Copyright © 2011-2022 走看看