zoukankan      html  css  js  c++  java
  • Tensorflow2.0-mnist手写数字识别示例

    Tensorflow2.0-mnist手写数字识别示例

         

          读书不觉春已深,一寸光阴一寸金。

    简介:通过CNN 卷积神经网络训练后识别出手写图片,测试图片mnist数据集中的0、1、2、4。

                       

    一、mnist数据集准备

         虽然可以通过代码自动下载数据集,但是mnist 数据集国内下载不稳定,会出现【Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz】的情况,代码从定义目录data_set_tf3 中未获取到mnist 数据集就会自动下载,但下载时间比较久,还是提前准备好。

    Downloading mnist data from https

    mnist数据集下载地址

    mnist数据集官网如上,下载下面四个东西就可以了,图中标红的两个images和lables。

    Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解压后 47 MB, 包含 60,000 个样本)

    Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解压后 60 KB, 包含 60,000 个标签)

    Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解压后 7.8 MB, 包含 10,000 个样本)

    Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解压后 10 KB, 包含 10,000 个标签)

          MNIST 数据集来自美国国家标准与技术研究所,  训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 的工作人员;测试集(test set) 也是同样比例的手写数字数据;可以新建一个文件夹 – mnist, 将数据集下载到 mnist 解压即可。

    mnist数据集整合

    三、图片训练

    train.py 训练代码如下:

     1 import os
     2 import tensorflow as tf
     3 from tensorflow.keras import datasets, layers, models
     4 
     5 '''
     6 python 3.7、3.9
     7 tensorflow 2.0.0b0
     8 '''
     9 
    10 # 模型定义的前半部分主要使用Keras.layers 提供的Conv2D(卷积)与MaxPooling2D(池化)函数。
    11 # CNN的输入是维度为(image_height, image_width, color_channels)的张量,
    12 # mnist数据集是黑白的,因此只有一个color_channels 颜色通道;一般的彩色图片有3个(R, G, B),
    13 # 也有4个通道的(R, G, B, A),A代表透明度;
    14 # 对于mnist数据集,输入的张量维度为(28, 28, 1),通过参数input_shapa 传给网络的第一层
    15 # CNN模型处理:
    16 class CNN(object):
    17     def __init__(self):
    18         model = models.Sequential()
    19         # 第1层卷积,卷积核大小为3*3,32个,28*28为待训练图片的大小
    20         model.add(layers.Conv2D(
    21             32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
    22         model.add(layers.MaxPooling2D((2, 2)))
    23         # 第2层卷积,卷积核大小为3*3,64个
    24         model.add(layers.Conv2D(64, (3, 3), activation='relu'))  # 使用神经网络中激活函数ReLu
    25         model.add(layers.MaxPooling2D((2, 2)))
    26         # 第3层卷积,卷积核大小为3*3,64个
    27         model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    28 
    29         model.add(layers.Flatten())
    30         model.add(layers.Dense(64, activation='relu'))
    31         model.add(layers.Dense(10, activation='softmax'))
    32         # Flatten层用来将输入“压平”,即把多维的输入一维化,常用在从卷积层到全连接层的过渡。Flatten不影响batch的大小
    33         # dense :全连接层相当于添加一个层
    34         # softmax用于多分类过程中,它将多个神经元的输出,映射到(0,1)区间内,可以看成概率来理解,从而来进行多分类!
    35         model.summary()  # 输出模型各层的参数状况
    36 
    37         self.model = model
    38 
    39 
    40 # mnist数据集预处理
    41 class DataSource(object):
    42     def __init__(self):
    43         # mnist数据集存储的位置,如果不存在将自动下载
    44         data_path = os.path.abspath(os.path.dirname(
    45             __file__)) + '/../data_set_tf2/mnist.npz'
    46         (train_images, train_labels), (test_images,
    47                                        test_labels) = datasets.mnist.load_data(path=data_path)
    48         # 6万张训练图片,1万张测试图片
    49         train_images = train_images.reshape((60000, 28, 28, 1))
    50         test_images = test_images.reshape((10000, 28, 28, 1))
    51         # 像素值映射到 0 - 1 之间
    52         train_images, test_images = train_images / 255.0, test_images / 255.0
    53 
    54         self.train_images, self.train_labels = train_images, train_labels
    55         self.test_images, self.test_labels = test_images, test_labels
    56 
    57 
    58 # 开始训练并保存训练结果
    59 class Train:
    60     def __init__(self):
    61         self.cnn = CNN()
    62         self.data = DataSource()
    63 
    64     def train(self):
    65         check_path = './ckpt/cp-{epoch:04d}.ckpt'
    66         # period 每隔5epoch保存一次
    67         save_model_cb = tf.keras.callbacks.ModelCheckpoint(
    68             check_path, save_weights_only=True, verbose=1, period=5)
    69 
    70         self.cnn.model.compile(optimizer='adam',
    71                                loss='sparse_categorical_crossentropy',
    72                                metrics=['accuracy'])
    73         self.cnn.model.fit(self.data.train_images, self.data.train_labels,
    74                            epochs=5, callbacks=[save_model_cb])
    75 
    76         test_loss, test_acc = self.cnn.model.evaluate(
    77             self.data.test_images, self.data.test_labels)
    78         print("准确率: %.4f,共测试了%d张图片 " % (test_acc, len(self.data.test_labels)))
    79 
    80 
    81 if __name__ == "__main__":
    82     app = Train()
    83     app.train()
    View Code~拍一拍小轮胎

     mnist手写数字识别训练了四分钟左右,准确率高达0.9902,下面的视频只截取了训练的前十秒。

     mnist手写数字识别训练视频


    model.summary()打印定义的模型结构

    CNN定义的模型结构

     1 Model: "sequential"
     2 _________________________________________________________________
     3 Layer (type)                 Output Shape              Param #   
     4 =================================================================
     5 conv2d (Conv2D)              (None, 26, 26, 32)        320       
     6 _________________________________________________________________
     7 max_pooling2d (MaxPooling2D) (None, 13, 13, 32)        0         
     8 _________________________________________________________________
     9 conv2d_1 (Conv2D)            (None, 11, 11, 64)        18496     
    10 _________________________________________________________________
    11 max_pooling2d_1 (MaxPooling2 (None, 5, 5, 64)          0         
    12 _________________________________________________________________
    13 conv2d_2 (Conv2D)            (None, 3, 3, 64)          36928     
    14 _________________________________________________________________
    15 flatten (Flatten)            (None, 576)               0         
    16 _________________________________________________________________
    17 dense (Dense)                (None, 64)                36928     
    18 _________________________________________________________________
    19 dense_1 (Dense)              (None, 10)                650       
    20 =================================================================
    21 Total params: 93,322
    22 Trainable params: 93,322
    23 Non-trainable params: 0
    24 _________________________________________________________________
    View Code

          我们可以看到,每一个Conv2D 和MaxPooling2D 层的输出都是一个三维的张量(height, width, channels),height 和width 会逐渐地变小;输出的channel 的个数,是由第一个参数(例如,32或64)控制的,随着height 和width 的变小,channel可以变大(从算力的角度)。

          模型的后半部分,是定义张量的输出。layers.Flatten 会将三维的张量转为一维的向量,展开前张量的维度是(3, 3, 64) ,转为一维(576)【3*3*64】的向量后,紧接着使用layers.Dense 层,构造了2层全连接层,逐步地将一维向量的位数从576变为64,再变为10。

          后半部分相当于是构建了一个隐藏层为64,输入层为576,输出层为10的普通的神经网络。最后一层的激活函数是softmax,10位恰好可以表达0-9十个数字。最大值的下标即可代表对应的数字,使用numpy 的argmax() 方法获取最大值下标,很容易计算得到预测值。

    train.py运行结果

          可以看到,在第一轮训练后,识别准确率达到了0.9536,五轮训练之后,使用测试集验证,准确率达到了0.9902。在第五轮时,模型参数成功保存在了./ckpt/cp-0005.ckpt,而且此时准确率为更高的0.9940,所以也并不是训练时间次数越久越好,过犹不及。可以加载保存的模型参数,恢复整个卷积神经网络,进行真实图片的预测。

    保存训练模型参数

    四、图片预测

    predict.py代码如下:

     1 import tensorflow as tf
     2 from PIL import Image
     3 import numpy as np
     4 
     5 from mnist.v4_cnn.train import CNN
     6 
     7 '''
     8 python 3.7 3.9
     9 tensorflow 2.0.0b0
    10 pillow(PIL) 4.3.0
    11 '''
    12 
    13 
    14 class Predict(object):
    15     def __init__(self):
    16         latest = tf.train.latest_checkpoint('./ckpt')
    17         self.cnn = CNN()
    18         # 恢复网络权重
    19         self.cnn.model.load_weights(latest)
    20 
    21     def predict(self, image_path):
    22         # 以黑白方式读取图片
    23         img = Image.open(image_path).convert('L')
    24         img = np.reshape(img, (28, 28, 1)) / 255.
    25         x = np.array([1 - img])
    26 
    27         # API refer: https://keras.io/models/model/
    28         y = self.cnn.model.predict(x)
    29 
    30         # 因为x只传入了一张图片,取y[0]即可
    31         # np.argmax()取得最大值的下标,即代表的数字
    32         print(image_path)
    33         print(y[0])
    34         print('        -> Predict picture number is: ', np.argmax(y[0]))
    35 
    36 
    37 if __name__ == "__main__":
    38     app = Predict()
    39     app.predict('../test_images/0.png')
    40     app.predict('../test_images/1.png')
    41     app.predict('../test_images/4.png')
    42     app.predict('../test_images/2.png')
    View Code

    预测结果

     预测结果:

     1 ../test_images/0.png
     2 [9.9999774e-01 2.6819215e-08 1.2541744e-07 8.7437911e-08 1.0661940e-09
     3  3.3693670e-08 4.6488995e-07 3.5915035e-09 9.8040758e-08 1.4385278e-06]
     4         -> Predict picture number is:  0
     5 ../test_images/1.png
     6 [7.75440956e-09 9.99991298e-01 1.41642090e-07 1.09819875e-10
     7  6.76554646e-06 7.63710162e-09 2.37024622e-08 1.58189516e-06
     8  2.49125264e-07 4.92376007e-09]
     9         -> Predict picture number is:  1
    10 ../test_images/4.png
    11 [7.03467840e-10 8.20740708e-04 1.11648405e-04 3.93262711e-09
    12  9.99048650e-01 1.08713095e-07 4.24647197e-08 1.85665340e-05
    13  5.03181887e-08 1.86591734e-07]
    14         -> Predict picture number is:  4
    15 ../test_images/2.png
    16 [1.5828672e-08 1.9245699e-07 9.9999440e-01 5.3448480e-06 1.7397912e-10
    17  8.6148493e-13 2.5441890e-10 5.3953073e-08 3.5735226e-08 8.9734775e-11]
    18         -> Predict picture number is:  2
    View Code

    如上,经CNN训练后通过模型参数准确预测出了0、1、2、4四张手写图片的真实值。

                     

        

     读书不觉春已深

                                一寸光阴一寸金

     

  • 相关阅读:
    MDX Step by Step 读书笔记(六) Building Complex Sets (复杂集合的处理) Filtering Sets
    在 Visual Studio 2012 开发 SSIS,SSAS,SSRS BI 项目
    微软BI 之SSIS 系列 在 SSIS 中读取 SharePoint List
    MDX Step by Step 读书笔记(五) Working with Expressions (MDX 表达式) Infinite Recursion 和 SOLVE_ORDER 原理解析
    MDX Step by Step 读书笔记(五) Working with Expressions (MDX 表达式)
    使用 SQL Server 2012 Analysis Services Tabular Mode 表格建模 图文教程
    MDX Step by Step 读书笔记(四) Working with Sets (使用集合) Limiting Set and AutoExists
    SQL Server 2012 Analysis Services Tabular Model 读书笔记
    Microsoft SQL Server 2008 MDX Step by Step 学习笔记连载目录
    2011新的开始,介绍一下AgileEAS.NET平台在新的一年中的发展方向
  • 原文地址:https://www.cnblogs.com/taojietaoge/p/14202887.html
Copyright © 2011-2022 走看看