zoukankan      html  css  js  c++  java
  • Tensorflow2.0笔记20——MNIST数据集(手写体数字识别)

    Tensorflow2.0笔记

    本博客为Tensorflow2.0学习笔记,感谢北京大学微电子学院曹建老师

    3.MNIST数据集(手写体数字识别)

    3.1 简介

    ​ MNIST 数据集一共有 7 万张图片,是 28×28 像素的 0 到 9 手写数字数据集, 其中 6 万张用于训练,1 万张用于测试。每张图片包括 784(28×28)个像素点, 使用全连接网络时可将 784 个像素点组成长度为 784 的一维数组,作为输入特征。数据集图片如下所示。

    image-20210622203636441

    3.1 导入数据集

    ​ keras 函数库中提供了使用 mnist 数据集的接口,代码如下所示,可以使用load_data()直接从 mnist 中读取测试集和训练集。

    mnist = tf.keras.datasets.mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data() 
    
    输入全连接网络时需要先将数据拉直为一维数组,把 784 个像素点的灰度值作为输入特征输入神经网络。 
    
    tf.keras.layers.Flatten()
    

    ​ 使用 plt 库中的两个函数可视化训练集中的图片。

    plt.imshow(x_train[0],cmap=’gray’)
    plt.show()
    

    image-20210622203933878

    ​ 使用 print 打印出训练集中第一个样本以二位数组的形式打印出来,如下所示。

    print(“x_train[0]:”,x_train[0]) 
    

    image-20210622204003482

    ​ 打印出第一个样本的标签,为 5。

    print("y_train[0]:",y_train[0]) y_train[0]:5 	
    

    ​ 打印出测试集样本的形状,共有 10000 个 28 行 28 列的三维数据。

    print(“x_test.shape:”x_test.shape) x_test.shape:(10000,28,28) 
    

    3.3训练MNIST数据集

    使用 Sequential 实现手写数字识别

    image-20210622204112539

    使用 class 实现手写数字识别

    image-20210622204125212

    值得注意的是训练时需要将输入特征的灰度值归一化到[0,1]区间,这可以使网络更快收敛。

    image-20210622204136696

    ​ 训练时每个 step 给出的是训练集 accuracy 不具有参考价值,有实际评判价值的是 validation_freq 中设置的隔若干轮输出的测试集 accuracy。如下图所示

  • 相关阅读:
    mirco新建proto流程
    Ubuntu默认防火墙安装、启用、配置、端口、查看状态相关信息
    Rails核心组件
    Ruby中文乱码问题
    python str转dict
    SQLserver AwaysOn日志文件过大,处理办法
    MySQL的一些小细节
    mysql删除表中重复值
    可恶的自增长标识符
    reset slave all更彻底
  • 原文地址:https://www.cnblogs.com/wind-and-sky/p/14920267.html
Copyright © 2011-2022 走看看