zoukankan      html  css  js  c++  java
  • Tensorflow2.0-mnist手写数字识别示例

    Tensorflow2.0-mnist手写数字识别示例

         

          读书不觉春已深,一寸光阴一寸金。

    简介:通过CNN 卷积神经网络训练后识别出手写图片,测试图片mnist数据集中的0、1、2、4。

                       

    一、mnist数据集准备

         虽然可以通过代码自动下载数据集,但是mnist 数据集国内下载不稳定,会出现【Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz】的情况,代码从定义目录data_set_tf3 中未获取到mnist 数据集就会自动下载,但下载时间比较久,还是提前准备好。

    Downloading mnist data from https

    mnist数据集下载地址

    mnist数据集官网如上,下载下面四个东西就可以了,图中标红的两个images和lables。

    Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解压后 47 MB, 包含 60,000 个样本)

    Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解压后 60 KB, 包含 60,000 个标签)

    Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解压后 7.8 MB, 包含 10,000 个样本)

    Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解压后 10 KB, 包含 10,000 个标签)

          MNIST 数据集来自美国国家标准与技术研究所,  训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 的工作人员;测试集(test set) 也是同样比例的手写数字数据;可以新建一个文件夹 – mnist, 将数据集下载到 mnist 解压即可。

    mnist数据集整合

    三、图片训练

    train.py 训练代码如下:

     1 import os
     2 import tensorflow as tf
     3 from tensorflow.keras import datasets, layers, models
     4 
     5 '''
     6 python 3.7、3.9
     7 tensorflow 2.0.0b0
     8 '''
     9 
    10 # 模型定义的前半部分主要使用Keras.layers 提供的Conv2D(卷积)与MaxPooling2D(池化)函数。
    11 # CNN的输入是维度为(image_height, image_width, color_channels)的张量,
    12 # mnist数据集是黑白的,因此只有一个color_channels 颜色通道;一般的彩色图片有3个(R, G, B),
    13 # 也有4个通道的(R, G, B, A),A代表透明度;
    14 # 对于mnist数据集,输入的张量维度为(28, 28, 1),通过参数input_shapa 传给网络的第一层
    15 # CNN模型处理:
    16 class CNN(object):
    17     def __init__(self):
    18         model = models.Sequential()
    19         # 第1层卷积,卷积核大小为3*3,32个,28*28为待训练图片的大小
    20         model.add(layers.Conv2D(
    21             32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
    22         model.add(layers.MaxPooling2D((2, 2)))
    23         # 第2层卷积,卷积核大小为3*3,64个
    24         model.add(layers.Conv2D(64, (3, 3), activation='relu'))  # 使用神经网络中激活函数ReLu
    25         model.add(layers.MaxPooling2D((2, 2)))
    26         # 第3层卷积,卷积核大小为3*3,64个
    27         model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    28 
    29         model.add(layers.Flatten())
    30         model.add(layers.Dense(64, activation='relu'))
    31         model.add(layers.Dense(10, activation='softmax'))
    32         # Flatten层用来将输入“压平”,即把多维的输入一维化,常用在从卷积层到全连接层的过渡。Flatten不影响batch的大小
    33         # dense :全连接层相当于添加一个层
    34         # softmax用于多分类过程中,它将多个神经元的输出,映射到(0,1)区间内,可以看成概率来理解,从而来进行多分类!
    35         model.summary()  # 输出模型各层的参数状况
    36 
    37         self.model = model
    38 
    39 
    40 # mnist数据集预处理
    41 class DataSource(object):
    42     def __init__(self):
    43         # mnist数据集存储的位置,如果不存在将自动下载
    44         data_path = os.path.abspath(os.path.dirname(
    45             __file__)) + '/../data_set_tf2/mnist.npz'
    46         (train_images, train_labels), (test_images,
    47                                        test_labels) = datasets.mnist.load_data(path=data_path)
    48         # 6万张训练图片,1万张测试图片
    49         train_images = train_images.reshape((60000, 28, 28, 1))
    50         test_images = test_images.reshape((10000, 28, 28, 1))
    51         # 像素值映射到 0 - 1 之间
    52         train_images, test_images = train_images / 255.0, test_images / 255.0
    53 
    54         self.train_images, self.train_labels = train_images, train_labels
    55         self.test_images, self.test_labels = test_images, test_labels
    56 
    57 
    58 # 开始训练并保存训练结果
    59 class Train:
    60     def __init__(self):
    61         self.cnn = CNN()
    62         self.data = DataSource()
    63 
    64     def train(self):
    65         check_path = './ckpt/cp-{epoch:04d}.ckpt'
    66         # period 每隔5epoch保存一次
    67         save_model_cb = tf.keras.callbacks.ModelCheckpoint(
    68             check_path, save_weights_only=True, verbose=1, period=5)
    69 
    70         self.cnn.model.compile(optimizer='adam',
    71                                loss='sparse_categorical_crossentropy',
    72                                metrics=['accuracy'])
    73         self.cnn.model.fit(self.data.train_images, self.data.train_labels,
    74                            epochs=5, callbacks=[save_model_cb])
    75 
    76         test_loss, test_acc = self.cnn.model.evaluate(
    77             self.data.test_images, self.data.test_labels)
    78         print("准确率: %.4f,共测试了%d张图片 " % (test_acc, len(self.data.test_labels)))
    79 
    80 
    81 if __name__ == "__main__":
    82     app = Train()
    83     app.train()
    View Code~拍一拍小轮胎

     mnist手写数字识别训练了四分钟左右,准确率高达0.9902,下面的视频只截取了训练的前十秒。

     mnist手写数字识别训练视频


    model.summary()打印定义的模型结构

    CNN定义的模型结构

     1 Model: "sequential"
     2 _________________________________________________________________
     3 Layer (type)                 Output Shape              Param #   
     4 =================================================================
     5 conv2d (Conv2D)              (None, 26, 26, 32)        320       
     6 _________________________________________________________________
     7 max_pooling2d (MaxPooling2D) (None, 13, 13, 32)        0         
     8 _________________________________________________________________
     9 conv2d_1 (Conv2D)            (None, 11, 11, 64)        18496     
    10 _________________________________________________________________
    11 max_pooling2d_1 (MaxPooling2 (None, 5, 5, 64)          0         
    12 _________________________________________________________________
    13 conv2d_2 (Conv2D)            (None, 3, 3, 64)          36928     
    14 _________________________________________________________________
    15 flatten (Flatten)            (None, 576)               0         
    16 _________________________________________________________________
    17 dense (Dense)                (None, 64)                36928     
    18 _________________________________________________________________
    19 dense_1 (Dense)              (None, 10)                650       
    20 =================================================================
    21 Total params: 93,322
    22 Trainable params: 93,322
    23 Non-trainable params: 0
    24 _________________________________________________________________
    View Code

          我们可以看到,每一个Conv2D 和MaxPooling2D 层的输出都是一个三维的张量(height, width, channels),height 和width 会逐渐地变小;输出的channel 的个数,是由第一个参数(例如,32或64)控制的,随着height 和width 的变小,channel可以变大(从算力的角度)。

          模型的后半部分,是定义张量的输出。layers.Flatten 会将三维的张量转为一维的向量,展开前张量的维度是(3, 3, 64) ,转为一维(576)【3*3*64】的向量后,紧接着使用layers.Dense 层,构造了2层全连接层,逐步地将一维向量的位数从576变为64,再变为10。

          后半部分相当于是构建了一个隐藏层为64,输入层为576,输出层为10的普通的神经网络。最后一层的激活函数是softmax,10位恰好可以表达0-9十个数字。最大值的下标即可代表对应的数字,使用numpy 的argmax() 方法获取最大值下标,很容易计算得到预测值。

    train.py运行结果

          可以看到,在第一轮训练后,识别准确率达到了0.9536,五轮训练之后,使用测试集验证,准确率达到了0.9902。在第五轮时,模型参数成功保存在了./ckpt/cp-0005.ckpt,而且此时准确率为更高的0.9940,所以也并不是训练时间次数越久越好,过犹不及。可以加载保存的模型参数,恢复整个卷积神经网络,进行真实图片的预测。

    保存训练模型参数

    四、图片预测

    predict.py代码如下:

     1 import tensorflow as tf
     2 from PIL import Image
     3 import numpy as np
     4 
     5 from mnist.v4_cnn.train import CNN
     6 
     7 '''
     8 python 3.7 3.9
     9 tensorflow 2.0.0b0
    10 pillow(PIL) 4.3.0
    11 '''
    12 
    13 
    14 class Predict(object):
    15     def __init__(self):
    16         latest = tf.train.latest_checkpoint('./ckpt')
    17         self.cnn = CNN()
    18         # 恢复网络权重
    19         self.cnn.model.load_weights(latest)
    20 
    21     def predict(self, image_path):
    22         # 以黑白方式读取图片
    23         img = Image.open(image_path).convert('L')
    24         img = np.reshape(img, (28, 28, 1)) / 255.
    25         x = np.array([1 - img])
    26 
    27         # API refer: https://keras.io/models/model/
    28         y = self.cnn.model.predict(x)
    29 
    30         # 因为x只传入了一张图片,取y[0]即可
    31         # np.argmax()取得最大值的下标,即代表的数字
    32         print(image_path)
    33         print(y[0])
    34         print('        -> Predict picture number is: ', np.argmax(y[0]))
    35 
    36 
    37 if __name__ == "__main__":
    38     app = Predict()
    39     app.predict('../test_images/0.png')
    40     app.predict('../test_images/1.png')
    41     app.predict('../test_images/4.png')
    42     app.predict('../test_images/2.png')
    View Code

    预测结果

     预测结果:

     1 ../test_images/0.png
     2 [9.9999774e-01 2.6819215e-08 1.2541744e-07 8.7437911e-08 1.0661940e-09
     3  3.3693670e-08 4.6488995e-07 3.5915035e-09 9.8040758e-08 1.4385278e-06]
     4         -> Predict picture number is:  0
     5 ../test_images/1.png
     6 [7.75440956e-09 9.99991298e-01 1.41642090e-07 1.09819875e-10
     7  6.76554646e-06 7.63710162e-09 2.37024622e-08 1.58189516e-06
     8  2.49125264e-07 4.92376007e-09]
     9         -> Predict picture number is:  1
    10 ../test_images/4.png
    11 [7.03467840e-10 8.20740708e-04 1.11648405e-04 3.93262711e-09
    12  9.99048650e-01 1.08713095e-07 4.24647197e-08 1.85665340e-05
    13  5.03181887e-08 1.86591734e-07]
    14         -> Predict picture number is:  4
    15 ../test_images/2.png
    16 [1.5828672e-08 1.9245699e-07 9.9999440e-01 5.3448480e-06 1.7397912e-10
    17  8.6148493e-13 2.5441890e-10 5.3953073e-08 3.5735226e-08 8.9734775e-11]
    18         -> Predict picture number is:  2
    View Code

    如上,经CNN训练后通过模型参数准确预测出了0、1、2、4四张手写图片的真实值。

                     

        

     读书不觉春已深

                                一寸光阴一寸金

     

  • 相关阅读:
    Intellij idea安装
    c# .net 我的Application_Error 全局异常抓取处理
    c# .net Global.asax文件的作用
    ASP.NET机制详细的管道事件流程(转)
    正则表达式_学习笔记
    c# .net获取随机字符串!
    c# 动态调用WCF方法笔记!
    Web Service和WCF的区别。其实二者不属于一个范畴!!!
    c# .net获取文件夹下的所有文件(多层递归),并获取区间数据(Jsion,xml等数据)写出到处理文件,学习分享~
    c#.net单例模式的学习记录!
  • 原文地址:https://www.cnblogs.com/taojietaoge/p/14202887.html
Copyright © 2011-2022 走看看