zoukankan      html  css  js  c++  java
  • mnist手写数字问题初体验

    上一篇我们提到了回归问题中的梯度下降算法,而且我们知道线性模型只能解决简单的线性回归问题,对于高维图片,线性模型不能完成这样复杂的分类任务。那么是不是线性模型在离散值预测或图像分类问题中就没有用武之地了呢?

    本篇我们就套用regression中的部分机制来处理classification中的问题。

    在这里首先介绍一下激活函数。

    所谓激活函数,实际上就是引入非线性因子,将线性模型去线性化,增强模型的表达能力。ReLU激活函数是我要介绍的第一个激活函数,其定义式为φ(z)=max{0,z},图像表示如下:

    简单的说relu就是一个取最大值的函数,在负区间取值为0,正区间取值不变,这种操作被称为单侧抑制(输出为0时代表神经元不会被激活)。单侧抑制的特点就是同一时间只会有一部分神经元被激活(结合函数图像可以看出),也就使得神经元具有了稀疏激活性。加入relu激活函数的神经元被称作整流线性单元,它与线性单元非常相似,唯一的区别就是在一半定义域上输出为0。整流线性单元易于优化,当其处于激活状态时(输出不为0),它的一阶导数能够保持一个较大值(等于1),并且处处一致,它的二阶导数几乎处处为0,这样的好处就是避免了梯度下降时的梯度消失问题(可参考前一篇回归问题的随笔)。

    简单介绍了激活函数,那么是不是将激活函数引入我们的线性模型out=X@w+b就能使其解决复杂的图像分类问题了呢?

    很显然不是的,虽然加了激活函数,但是我们可以看到模型变为out=relu(X@w+b) 依然还是太简单。那么怎么办呢?

    我们可以联系一下零件加工的流程,从原料到成品,零件的加工经历了多个工序,期间每一道工序都是由前一道工序为基础,这时候,原料就相当于神经网络的输入,成品零件就相当于神经网络的输出,他们中间并不是也不能一步到位,而是经过若干“隐藏”的工序一步一步的生成产品。我们的模型同样可以借助于这种思想。即给数据处理多添加几道所谓的“工序”,我们称之为“隐藏层”,因为我们关心的只有模型的输入和输出,隐藏层的数据是我们不可见的(当然也可以在运行过程中打印出来方便调试),下面我们就利用这样的思想来解决mnist手写数字分类问题。

    我们使用的是mnist数据集,也是深度学习的基础入门数据集。它一共有70k张不同的手写数字图片,其中60k用来训练模型,10k用来评估模型,且所有图片均为28*28的灰度图。我们首先设计一个稍微复杂的模型

    h1=relu(X@w1+b1)

    h2=relu(h1@w2+b2)

    out=relu(h2@w3+b3)

    其中X为输入,out为输出,h1、h2均为隐藏层,且除输入层外每一层的输入均是前一层的输出。

    首先我们将输入的28*28*1的图片扁平化,即将每张图片转化成784维的向量(28*28=784),这样的好处是可以以矩阵的形式同时喂入多张图片(每一行向量为一张灰度图的信息),提高效率。对于输出out,我们令其输出一个10维的向量,代表10个数字的概率。模型可以用以下公式概括:

    out=relu { relu { relu[ X@w1+b1 ] @w2+b2 }@w3+b3 }

    pred=argmax(out)

    loss=MSE(out,label) (均方误差损失函数即loss=∑(label-out)2

    minimize loss→[w1',b1',w2',b2',w3',b3']

    参数调整完成后,可以对新的输入x进行运算从而得到对应的输出

    代码如下:

     1 import os
     2 import tensorflow as tf
     3 from tensorflow import keras
     4 from tensorflow.keras import layers, optimizers, datasets
     5 
     6 # 屏蔽通知和警告信息,减少用处不大的问题输出
     7 os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
     8 
     9 (x, y), (x_val, y_val) = datasets.mnist.load_data() 
    10 x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.
    11 y = tf.convert_to_tensor(y, dtype=tf.int32)
    12 y = tf.one_hot(y, depth=10)
    13 print(x.shape, y.shape)
    14 train_dataset = tf.data.Dataset.from_tensor_slices((x, y))
    15 train_dataset = train_dataset.batch(200)
    16 
    17 # 搭建网络结构
    18 model = keras.Sequential([ 
    19     layers.Dense(512, activation='relu'),
    20     layers.Dense(256, activation='relu'),
    21     layers.Dense(10)])
    22 
    23 # 初始化优化器为梯度下降优化器
    24 optimizer = optimizers.SGD(learning_rate=0.001)
    25 
    26 def train_epoch(epoch):
    27 
    28     # Step4.循环迭代
    29     for step, (x, y) in enumerate(train_dataset):
    30 
    31 
    32         with tf.GradientTape() as tape:
    33             # 将输入数据压平 [b, 28, 28] => [b, 784]
    34             x = tf.reshape(x, (-1, 28*28))
    35             # Step1. 计算输出
    36             # 输入域数据经过神经网络降维 [b, 784] => [b, 10]
    37             out = model(x)
    38             # Step2. 计算损失
    39             loss = tf.reduce_sum(tf.square(out - y)) / x.shape[0]
    40 
    41         # Step3. 优化更新参数 w1, w2, w3, b1, b2, b3
    42         grads = tape.gradient(loss, model.trainable_variables)
    43         # w' = w - lr * grad
    44         optimizer.apply_gradients(zip(grads, model.trainable_variables))
    45 
    46         if step % 100 == 0:
    47             print(epoch, step, 'loss:', loss.numpy())
    48 
    49 def train():
    50 
    51     for epoch in range(30):
    52 
    53         train_epoch(epoch)
    54 
    55 
    56 if __name__ == '__main__':
    57     train()

    运行结果如下:

       

    可以看到损失从初始的1.65降到0.25,在这里我们先只对mnist进行一个初步探索,测试一下模型的表现,后续会通过一些更好的优化方法来不断改良我们的模型。

  • 相关阅读:
    BOI 2002 双调路径
    BOI'98 DAY 2 TASK 1 CONFERENCE CALL Dijkstra/Dijkstra+priority_queue/SPFA
    USACO 2013 November Contest, Silver Problem 2. Crowded Cows 单调队列
    BOI 2003 Problem. Spaceship
    USACO 2006 November Contest Problem. Road Blocks SPFA
    CEOI 2004 Trial session Problem. Journey DFS
    USACO 2015 January Contest, Silver Problem 2. Cow Routing Dijkstra
    LG P1233 木棍加工 动态规划,Dilworth
    LG P1020 导弹拦截 Dilworth
    USACO 2007 February Contest, Silver Problem 3. Silver Cow Party SPFA
  • 原文地址:https://www.cnblogs.com/zdm-code/p/12190537.html
Copyright © 2011-2022 走看看