zoukankan      html  css  js  c++  java
  • 一起学TensorFlow---搭建最简单的全连接网络实现手写数字识别(MNIST)

    刚开始学Tensorflow,这里记录学习中的点点滴滴,希望能和大家共同进步。

    Cuda和Tensorflow的安装请参考上一篇博客:http://www.cnblogs.com/roboai/p/7768191.html

    Tensorflow简单介绍

      我们知道,一维的数据可以用数组表示,二维可以用矩阵表示,那么三维或三维以上呢?比如图像,实际上就是一个三维数据[h,w,c],高、宽、通道数,对于灰度图来说,通道数为1,而对于彩色图像,通道数为3。对于这种三维或三维以上的数据,我们称之为张量(tensor),所以顾名思义,Tensorflow的意思就是张量的流动,Tensorflow将数据打包成一个个张量,由四个维度构成,分别是[batch, height, width, channels],然后在各个节点之间传递。

      节点是Tensorflow里另一重要的概念,对张量的操作称之为节点,一系列的节点构成图。接触过Caffe的朋友可能发现了,这和Caffe里的blob、layer、net是一致的。不同的是,我们需要启动一个会话来计算图,这是Tensorflow的内在机制所决定的。Tensorflow依赖于一个高效的C++后端来进行计算,与后端的这个连接叫做session。一般而言,使用TensorFlow程序的流程是先创建一个图,然后在session中启动它。其思想是先让我们描述一个交互操作图,然后完全将其运行在Python外部。这样做的目的是为了避免频繁切换Python环境和外部环境时需要的开销。如果你想在GPU或者分布式环境中计算时,这一开销会非常可怖,这一开销主要可能是用来进行数据迁移,并不能对计算做出贡献。

      我们构建一个简单的图来说明以上过程,改图包含三个节点(两个源节点和一个矩阵乘法节点),然后启动一个会话计算图得到输出结果,最后需要关闭会话。当然也可以使用with代码块实现自动关闭,效果是一样的。

    # coding=utf-8
    import tensorflow as tf
    
    # 该图包含3个节点(两个源节点和乘法节点)
    matrix1 = tf.constant([[3, 3]])
    matrix2 = tf.constant([[2], [2]])
    product = tf.matmul(matrix1, matrix2)
    
    # 调用会话启动图
    sess = tf.Session()
    result = sess.run(product)
    
    # 输出结果并关闭会话
    print result
    sess.close()
    
    # 使用“with”代码块自动关闭, 该方法更简洁
    with tf.Session() as sess:
        result = sess.run(product)
        print result

    输出结果为

    [[12]]
    [[12]]

    MNIST数据集

      MNIST是一个入门级的计算机视觉数据集,它包含各种手写数字图片,也包含每一张图片对应的标签,告诉我们这个是数字几。新建一个get.sh文件,写入以下内容,执行该文件就可以下载该数据集。下载下来的数据集被分成两部分,60000行的训练数据集和10000行的测试数据集。每一张图片包含28X28个像素点,我们可以把图片展开成一个向量,长度是 28x28 = 784。

    #!/usr/bin/env sh
    # This scripts downloads the mnist data and unzips it.
    
    DIR="$( cd "$(dirname "$0")" ; pwd -P )"
    cd "$DIR"
    
    echo "Downloading..."
    
    for fname in train-images-idx3-ubyte train-labels-idx1-ubyte t10k-images-idx3-ubyte t10k-labels-idx1-ubyte
    do
        if [ ! -e $fname ]; then
            wget --no-check-certificate http://yann.lecun.com/exdb/mnist/${fname}.gz
        fi
    done

    Softmax Regression与Cross Entropy

      在本文中,我们将采用最简单的网络来预测输入图片中的数字,整个网络仅由一个Softmax Regression构成,数学模型可以写作(y=softmax(Wx+b))。假设(y')是实际分布,(y)是预测分布,Cross Entropy的定义是(loss=sum{y'log{y}})。关于Softmax Regression的反向传递及Cross Entropy的物理含义请参考以下两篇博客,这里就不展开写了。

      http://ufldl.stanford.edu/wiki/index.php/Softmax%E5%9B%9E%E5%BD%92  

      http://blog.csdn.net/rtygbwwwerr/article/details/50778098

    全连接网络实现手写数字识别

      下面终于进入正题了,我们有了数据集,同时也了解了算法流程,剩下的就是写代码实现了。首先是导入包,由于Tensorflow帮我们写了一部分数据读写的程序,我们这里就直接用了。

    # coding=utf-8
    import tensorflow.examples.tutorials.mnist.input_data as input_data
    import tensorflow as tf
    
    # 导入数据, 强烈建议预先下载
    mnist = input_data.read_data_sets("data/", one_hot=True)

      这里数据可以用我前面给出的get.sh下载,然后放入data文件夹目录下,我之前是直接用input_data.read_data_sets("data/", one_hot=True)下载的,结果半天下载不下来,所以这里还是建议预先下载吧,用get.sh下载比较快。然后是程序的主要部分。

    # 训练集占位符:28*28=784
    x = tf.placeholder(tf.float32, [None, 784])
    # 初始化参数
    W = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))
    # 输出结果
    y = tf.nn.softmax(tf.matmul(x, W) + b)
    # 真实值
    y_ = tf.placeholder(tf.float32, [None, 10])
    # 计算交叉熵
    crossEntropy = -tf.reduce_sum(y_*tf.log(y))
    # 训练策略
    trainStep = tf.train.GradientDescentOptimizer(0.01).minimize(crossEntropy)
    # 初始化参数值
    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)
    
    # 开始训练:循环训练1000次
    for i in range(1000):
        batchXs, batchYs = mnist.train.next_batch(100)
        sess.run(trainStep, feed_dict={x: batchXs, y_: batchYs})
    
    # 评估模型
    correctPrediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correctPrediction, tf.float32))
    print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})

      这里用的是占位符的方式传入数据,占位符的尺寸为[None, 784],这里的None表示此张量的第一个维度可以是任何长度的。

      权重值W和偏置量b使用Variable来表示,一个Variable代表一个可修改的张量,存在在Tensorflow的用于描述交互性操作的图中。它们可以用于计算输入值,也可以在计算中被修改。对于各种机器学习应用,一般都会有模型参数,都可以用Variable表示。在这里,我们都用全为零的张量来初始化Wb。

      只需要一行代码就可以实现我们的模型y = tf.nn.softmax(tf.matmul(x, W) + b),同样损失函数也只需要一行代码crossEntropy = -tf.reduce_sum(y_*tf.log(y))。

      以0.01的学习速率,采用梯度下降法最小化交叉熵,对应的代码为trainStep = tf.train.GradientDescentOptimizer(0.01).minimize(crossEntropy)。

      然后初始化参数并训练,定义训练次数为1000,每次随机地选取100图像进行计算。

      最后对得到的模型使用测试数据进行评估,评估结果表明精度达到0.9148(每次都不一样,在91%左右徘徊)。

      至此,我们采用最简单的一个全连接网络实现了一个手写数字识别的网络,剩下的工作是将这个网络及参数保存,采用自己的图片进行识别,进一步感受这个网络的效果,这一部分将在后续的工作中进行。同时我们可以说这个网络过于简单了,91%的识别效果也远远达不到我们的需求,如何进一步提高网络的精度是我们关注的重点。


    关于会话

      会话(session)提供在图中执行操作的一些方法。一般的模式是:

    1.  建立会话,此时会生成一张空图;
    2.  在会话中添加节点和边,形成一张图;
    3.  执行图

      在调用Session对象的run()方法来执行图时,传入一些Tensor,这个过程叫填充(feed);返回的结果类型根据输入的类型而定,这个过程叫取回(fetch)。

      会话是图交互的桥梁,一个会话可以有多个图,会话可以修改图的结构,也可以往图中注入数据进行计算。因此,会话主要由两个API接口--Extend和Run。Extend操作是在Graph中添加节点和边,Run操作是输入计算的节点和填充必要的数据后,进行计算,并输出运算结果。

    关于节点与图

      图中的节点又称为算子,它代表一个操作(Operation,op),一般用来表示施加的数学运算,也可以表示数据输入(feed in)的起点以及输出(push out)的终点,或者是读取/写入持久变量(persistent variable)的终点。

      如果不显式添加一个默认图,系统会自动设置一个全局的默认图。所设置的默认图,在模块范围内定义的节点都将默认加入默认图中。

    关于可视化

      可视化时,需要在程序中给必要的节点添加摘要(summary),摘要会收集该节点的数据,并标记上第几步、时间戳等标识,写入事件文件(event file)中。

    模型存储与加载

      TensorFLow的API提供了两种方式存储和加载模型:

      (1)生成检查点文件,拓展名一般为.ckpt,通过tf.train.Saver.save()生成。它包含权重和程序中定义的变量,不包含图结构。如果需要在另一个程序中使用,需要重新构建图结构,并告诉TensorFlow如何处理这些权重。

      (2)生成图协议文件,这是一个二进制文件,拓展名一般为.pb,用tf.train.write_graph()保存,只包含图形结构,不包含权重,然后使用tf.import_graph_def()来加载图形。

    模型训练之Momentum

      Momentum是模拟物理学中的动量的概念,更新时在一定程度上保留之前的更新方向,利用当前的批次再微调本次的更新参数,因此引入了一个新的变量v(速度),作为前几次梯度的累加。因此,Momentum能够改善训练过程,在下降初期,前后梯度一致时,能够加速学习;在下降的中后期,在局部最小值附近来回震荡时,能够抑制震荡,加快收敛。

  • 相关阅读:
    【7】用Laravel5.1开发一个简单的博客系统
    【6】Laravel5.1的migration数据库迁移
    【5】说说Laravel5的blade模板
    【4】优化一下【3】的例子,顺便说说细节
    【3】创建一个简单的Laravel例子
    【2】最简单的Laravel5.1程序分析
    【1】Laravel5.1 安装
    【0】Laravel 5.1 简介
    MySQL常用命令
    Windows8.1使用博客客户端写博客
  • 原文地址:https://www.cnblogs.com/roboai/p/7792954.html
Copyright © 2011-2022 走看看