zoukankan      html  css  js  c++  java
  • 基于CNN的手写数字识别程序

    基于CNN的手写数字识别程序

    一、数据准备

    训练及测试数据采用Tensorflow官方提供的MNIST数据集,具体内容如下表所示:

    文件 内容
    图片信息 大小为28*28的灰度手写数字图像,数字从0到9
    train-images-idx3-ubyte.gz 训练集图片,55000张训练图片,5000张验证图片
    train-labels-idx1-ubyte.gz 训练集图片对应的数字标签
    t10k-images-idx3-ubyte.gz 测试集图片,共10000张
    t10k-labels-idx1-ubyte.gz 测试集图片对应的数字标签

    程序中数据导入代码如下:

    from tensorflow.examples.tutorials.mnist import input_data
    minst = input_data.read_data_sets('/tmp/data', one_hot=True)
    

    另外也将Tensorflow中一些常用操作封装成函数,便于调用。

    import tensorflow.compat.v1 as tf
    #权重W初始化函数
    def weight_variable(shape):
        initial = tf.truncated_normal(shape=shape, stddev=0.1)
        #从标准偏差为0.1的正态分布中截取数值进行初始化
        return tf.Variable(initial)
     
    #偏置b初始化函数
    def bias_variable(shape):
        initial = tf.constant(0.1, shape=shape)
        #值为0.1
        return tf.Variable(initial)
    
    

    二、网络结构

    采用LeNet卷积神经网络典型结构,结构图如下:

    2

    1.卷积

    (1)作用

    ​ 以若干个卷积核部分覆盖在输入图上,按照设定的步长遍历输入图并进行卷积运算,从而达到提取图像特征的目的。

    (2)卷积核

    ​ 卷积核一般为尺寸不大于输入图的方阵,以符合一定规律的随机值(如截取自正态分布)进行初始化。本程序设两次卷积操作,使用的卷积核设定如下:

    卷积核 尺寸 输入通道数 输出通道数
    filter_1 5*5 1 32
    filter_2 5*5 32 64
    (3)激活

    ​ 为使分类结果更符合要求,卷积计算后的结果需要再加上偏置b(初始化为0.1)。上述运算均为线性运算。在此之后,采用ReLU非线性函数继续运算,也就是激活。

    (4)程序实现

    将上述步骤封装为一个卷积层函数Conv2d()

    #卷积层函数,返回值为卷积特征图 
    def Conv2d(image, shape):
        #shape为卷积核参数,格式[长,宽,输入通道数,输出通道数]
        w = weight_variable(shape)
        b = bias_variable([shape[3]])
        #shape[3]即shape第三维的值,即输出通道数
        res_conv = tf.nn.conv2d(input=image, filter=w, strides=[1,1,1,1], padding='SAME') + b
        return tf.nn.relu(res_conv)
    
    

    其中函数tf.nn.conv2d(input,filter,strides,padding)是Tensorflow提供的卷积函数,参数说明如下:

    • input,待卷积图像;

    • filter,卷积核;

    • strides,步长,格式为[1,横向步长,纵向步长,1]

    • padding,填充图像边缘('SAME')/不填充图像边缘('VALID')。若填充图像边缘,则使得图像边缘像素也能成为卷积中心。

    2.池化

    (1)作用

    ​ 在输入特征图上滑动一个窗口,取窗口中的某些特定数值(如最大值、平均值)构成池化特征图。池化操作可以概括图像局部信息,改变特征图的大小,但不改变其通道数。

    (2)池化窗

    ​ 类似于卷积核,需要设定尺寸和步长。程序中池化操作均采用大小为5*5,步长为1的池化窗。但不作为一个单独变量,而体现在池化函数的参数中。

    (3)程序实现

    ​ 程序采用最大池化,封装为一个池化函数MaxPool()

    def MaxPool(image):
        return tf.nn.max_pool(image, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')
    

    实际上只是调用了函数tf.nn.max_pool(image, ksize, strides, padding)参数说明如下:

    • input,输入特征图图像;

    • ksize,池化窗的大小,取一个四维向量,一般是[1, height, width, 1];

    • strides,步长,格式为[1,横向步长,纵向步长,1];

    • padding,填充图像边缘('SAME')/不填充图像边缘('VALID').

    3.全连接

    (1)作用

    ​ 对特征图进行“投票”,从而得到一个特征在各个类别的概率。

    (2)扁平化

    ​ 将经过一系列卷积和池化操作后所得到的特征图展开成一维,便于全连接。

    (3)程序实现
    def Flat(input,size):
        ori_size = int(input.get_shape()[1])
        w = weight_variable(shape = [ori_size, size])
        b = bias_variable(shape = [size])
        return tf.matmul(input, w) + b
    
    #扁平化
    x_flat = tf.reshape(res_pool2, shape=[-1, 7 * 7 * 64])
    res_flat = tf.nn.relu(Flat(x_flat, 1024))
    #全连接
    keep_prob = tf.placeholder(tf.float32)
    full1_drop = tf.nn.dropout(res_flat, keep_prob=keep_prob)
    #dropout防止过拟合,其中keep_prob为神经元保留率
    res_y = Flat(full1_drop, 10)
    #输出分类结果,即0~9共10个标签
    

    三、输出

    1.Softmax

    2.交叉熵损失函数

    3.优化算法Adam

    四、训练结果

  • 相关阅读:
    谁记录了mysql error log中的超长信息(记pt-stalk一个bug的定位过程)
    谈谈MySQL无法连接的原因和分析方法
    MySQL 5.7基于GTID复制的常见问题和修复步骤(二)
    日常运维故障记录和解决
    python学习之-- 故障记录汇总
    sshpass-Linux命令之非交互SSH密码验证
    python 之 线程池实现并发
    python 之 实现su 到root账号
    shell
    shell
  • 原文地址:https://www.cnblogs.com/streamazure/p/13866907.html
Copyright © 2011-2022 走看看