zoukankan      html  css  js  c++  java
  • TensorFlow 训练MNIST数据集(2)—— 多层神经网络

      在我的上一篇随笔中,采用了单层神经网络来对MNIST进行训练,在测试集中只有约90%的正确率。这次换一种神经网络(多层神经网络)来进行训练和测试。

    1、获取MNIST数据

      MNIST数据集只要一行代码就可以获取的到,非常方便。关于MNIST的基本信息可以参考我的上一篇随笔。

    mnist = input_data.read_data_sets('./data/mnist', one_hot=True)

    2、模型基本结构

      本次采用的训练模型为三层神经网络结构,输入层节点数与MNIST一行数据的长度一致,为784;输出层节点数与数字的类别数一致,为10;隐藏层节点数为50个;每次训练的mini-batch数量为64,;最大训练周期为50000。

    1 inputSize  = 784
    2 outputSize = 10
    3 hiddenSize = 50
    4 batchSize  = 64
    5 trainCycle = 50000

    3、输入层

      输入层用于接收每次小批量样本的输入,先通过placeholder来进行占位,在训练时才传入具体的数据。值得注意的是,在生成输入层的tensor时,传入的shape中有一个‘None’,表示每次输入的样本的数量,该‘None’表示先不作具体的指定,在真正输入的时候再根据实际的数据来进行推断。这个很方便,但也是有条件的,也就是通过该方法返回的tensor不能使用简单的加(+)减(-)乘(*)除(/)符号来进行计算(否则将会报错),需要用TensorFlow中的相关函数来进行代替。

    inputLayer = tf.placeholder(tf.float32, shape=[None, inputSize])

    4、隐藏层

      在神经网络中,隐藏层的作用主要是提取数据的特征(feature)。这里的权重参数采用了 tensorflow.truncated_normal() 函数来进行生成,与上次采用的 tensorflow.

    random_normal() 不一样。这两者的作用都是生成指定形状、期望和标准差的符合正太分布随机变量。区别是 truncated_normal 函数对随机变量的范围有个限制(与期望的偏差在2个标准差之内,否则丢弃)。另外偏差项这里也使用了变量的形式,也可以采用常量来进行替代。 

      激活函数为sigmoid函数。

    1 hiddenWeight = tf.Variable(tf.truncated_normal([inputSize, hiddenSize], mean=0, stddev=0.1))
    2 hiddenBias   = tf.Variable(tf.truncated_normal([hiddenSize]))
    3 hiddenLayer  = tf.add(tf.matmul(inputLayer, hiddenWeight), hiddenBias)
    4 hiddenLayer  = tf.nn.sigmoid(hiddenLayer)

    5、输出层

      输出层与隐藏层类似,只是节点数不一样。

    1 outputWeight = tf.Variable(tf.truncated_normal([hiddenSize, outputSize], mean=0, stddev=0.1))
    2 outputBias   = tf.Variable(tf.truncated_normal([outputSize], mean=0, stddev=0.1))
    3 outputLayer  = tf.add(tf.matmul(hiddenLayer, outputWeight), outputBias)
    4 outputLayer  = tf.nn.sigmoid(outputLayer)

    6、输出标签

      跟输入层一样,也是先占位,在最后训练的时候再传入具体的数据。标签,也就是每一个样本的正确分类。

    outputLabel = tf.placeholder(tf.float32, shape=[None, outputSize])

    7、损失函数

      这里采用的是交叉熵损失函数。注意用的是v2版本,第一个版本已被TensorFlow声明为deprecated,准备废弃了。

    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=outputLabel, logits=outputLayer))

    8、优化器与目标函数

      优化器采用了Adam梯度下降法,我试过了普通的GradientDescentOptimizer,效果不如Adam;也用过Adadelta,结果几乎收敛不了。

      目标函数就是最小化损失函数。

    optimizer = tf.train.AdamOptimizer()
    target    = optimizer.minimize(loss)

    9、训练过程

      先创建一个会话,然后初始化tensors,最后进行迭代训练。模型的收敛速度很快,在1000次的时候就达到了大概90%的正确率。

     1 with tf.Session() as sess:
     2     sess.run(tf.global_variables_initializer())
     3 
     4     for i in range(trainCycle):
     5         batch = mnist.train.next_batch(batchSize)
     6         sess.run(target, feed_dict={inputLayer: batch[0], outputLabel: batch[1]})
     7 
     8         if i % 1000 == 0:
     9             corrected = tf.equal(tf.argmax(outputLabel, 1), tf.argmax(outputLayer, 1))
    10             accuracy = tf.reduce_mean(tf.cast(corrected, tf.float32))
    11             accuracyValue = sess.run(accuracy, feed_dict={inputLayer: batch[0], outputLabel: batch[1]})
    12             print(i, 'train set accuracy:', accuracyValue)

    模型训练输出:

    10、测试训练结果

      在测数据集上测试。准确率达到96%,比单层的神经网络好很多。

    1     corrected = tf.equal(tf.argmax(outputLabel, 1), tf.argmax(outputLayer, 1))
    2     accuracy  = tf.reduce_mean(tf.cast(corrected, tf.float32))
    3     accuracyValue = sess.run(accuracy, feed_dict={inputLayer: mnist.test.images, outputLabel: mnist.test.labels})
    4     print("accuracy on test set:", accuracyValue)

    测试集上的输出:

     

    附:

      完整代码如下:

     1 import tensorflow as tf
     2 from tensorflow.examples.tutorials.mnist import input_data
     3 
     4 mnist = input_data.read_data_sets('./data/mnist', one_hot=True)
     5 
     6 inputSize  = 784
     7 outputSize = 10
     8 hiddenSize = 50
     9 batchSize  = 64
    10 trainCycle = 50000
    11 
    12 # 输入层
    13 inputLayer = tf.placeholder(tf.float32, shape=[None, inputSize])
    14 
    15 # 隐藏层
    16 hiddenWeight = tf.Variable(tf.truncated_normal([inputSize, hiddenSize], mean=0, stddev=0.1))
    17 hiddenBias   = tf.Variable(tf.truncated_normal([hiddenSize]))
    18 hiddenLayer  = tf.add(tf.matmul(inputLayer, hiddenWeight), hiddenBias)
    19 hiddenLayer  = tf.nn.sigmoid(hiddenLayer)
    20 
    21 # 输出层
    22 outputWeight = tf.Variable(tf.truncated_normal([hiddenSize, outputSize], mean=0, stddev=0.1))
    23 outputBias   = tf.Variable(tf.truncated_normal([outputSize], mean=0, stddev=0.1))
    24 outputLayer  = tf.add(tf.matmul(hiddenLayer, outputWeight), outputBias)
    25 outputLayer  = tf.nn.sigmoid(outputLayer)
    26 
    27 # 标签
    28 outputLabel = tf.placeholder(tf.float32, shape=[None, outputSize])
    29 
    30 # 损失函数
    31 loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=outputLabel, logits=outputLayer))
    32 
    33 # 优化器
    34 optimizer = tf.train.AdamOptimizer()
    35 
    36 # 训练目标
    37 target = optimizer.minimize(loss)
    38 
    39 # 训练
    40 with tf.Session() as sess:
    41     sess.run(tf.global_variables_initializer())
    42 
    43     for i in range(trainCycle):
    44         batch = mnist.train.next_batch(batchSize)
    45         sess.run(target, feed_dict={inputLayer: batch[0], outputLabel: batch[1]})
    46 
    47         if i % 1000 == 0:
    48             corrected = tf.equal(tf.argmax(outputLabel, 1), tf.argmax(outputLayer, 1))
    49             accuracy = tf.reduce_mean(tf.cast(corrected, tf.float32))
    50             accuracyValue = sess.run(accuracy, feed_dict={inputLayer: batch[0], outputLabel: batch[1]})
    51             print(i, 'train set accuracy:', accuracyValue)
    52 
    53     # 测试
    54     corrected = tf.equal(tf.argmax(outputLabel, 1), tf.argmax(outputLayer, 1))
    55     accuracy  = tf.reduce_mean(tf.cast(corrected, tf.float32))
    56     accuracyValue = sess.run(accuracy, feed_dict={inputLayer: mnist.test.images, outputLabel: mnist.test.labels})
    57     print("accuracy on test set:", accuracyValue)
    58 
    59     sess.close()
    View Code

    本文地址:https://www.cnblogs.com/laishenghao/p/9736696.html

  • 相关阅读:
    CCF认证201809-2买菜
    git删除本地保存的账号和密码
    mysql表分区
    使用java代码将时间戳和时间互相转换
    Mysql数据库表被锁定处理
    mysql查询某个数据库表的数量
    编译nginx错误:make[1]: *** [/pcre//Makefile] Error 127
    LINUX下安装pcre出现WARNING: 'aclocal-1.15' is missing on your system错误的解决办法
    linux下安装perl
    [剑指Offer]26-树的子结构
  • 原文地址:https://www.cnblogs.com/laishenghao/p/9736696.html
Copyright © 2011-2022 走看看