zoukankan      html  css  js  c++  java
  • TensorFlow训练MNIST数据集(1) —— softmax 单层神经网络

    1、MNIST数据集简介

      首先通过下面两行代码获取到TensorFlow内置的MNIST数据集:

    from tensorflow.examples.tutorials.mnist import input_data
    
    mnist = input_data.read_data_sets('./data/mnist', one_hot=True)

      MNIST数据集共有55000(mnist.train.num_examples)张用于训练的数据,对应的有55000个标签;共有10000(mnist.test.num_examples)张用于测试的图片的数据,同样的有10000个标签与之对应。为了方便访问,这些图片或标签的数据都是被格式化了的。

      MNIST数据集的训练数据集(mnist.train.images)是一个 55000 * 784 的矩阵,矩阵的每一行代表一张图片(28 * 28 * 1)的数据,图片的数据范围是 [0, 1],代表像素点灰度归一化后的值。

      训练集的标签(mnist.train.labels)是一个55000 * 10 的矩阵,每一行的10个数字分别代表对应的图片属于数字0到9的概率,范围是0或1。一个标签行只有一个是1,表示该图片的正确数字是对应的下标值, 其余是0。

      测试集与训练集的类似,只是数据量不同。

      以下代码显示部分MNIST训练图片的形状及标签:

    import numpy as np
    import matplotlib.pyplot as plot
    from tensorflow.examples.tutorials.mnist import input_data
    
    mnist = input_data.read_data_sets('./data/mnist', one_hot=True)
    trainImages = mnist.train.images
    trainLabels = mnist.train.labels
    
    plot.figure(1, figsize=(4, 3))
    for i in range(6):
        curImage = np.reshape(trainImages[i, :], (28, 28))
        curLabel = np.argmax(trainLabels[i, :])
        ax = plot.subplot(int(str(23) + str(i+1)))
        plot.imshow(curImage, cmap=plot.get_cmap('gray'))
        plot.axis('off')
        ax.set_title(curLabel)
    
    plot.suptitle('MNIST')
    plot.show()

      上述代码输出的MNIST图片及其标签:

     2、通过单层神经网络进行训练

     1 def train(trainCycle=50000, debug=False):
     2     inputSize  = 784
     3     outputSize = 10
     4     batchSize  = 64
     5     inputs = tf.placeholder(tf.float32, shape=[None, inputSize])
     6 
     7     # x * w = [64, 784] * [784, 10]
     8     weights   = tf.Variable(tf.random_normal([784, 10], 0, 0.1))
     9     bias      = tf.Variable(tf.random_normal([outputSize], 0, 0.1))
    10     outputs   = tf.add(tf.matmul(inputs, weights), bias)
    11     outputs   = tf.nn.softmax(outputs)
    12 
    13     labels = tf.placeholder(tf.float32, shape=[None, outputSize])
    14 
    15     loss      = tf.reduce_mean(tf.square(outputs - labels))
    16     optimizer = tf.train.GradientDescentOptimizer(0.1)
    17     trainer   = optimizer.minimize(loss)
    18 
    19     sess = tf.Session()
    20     sess.run(tf.global_variables_initializer())
    21     for i in range(trainCycle):
    22         batch = mnist.train.next_batch(batchSize)
    23         sess.run([trainer, loss], feed_dict={inputs: batch[0], labels: batch[1]})
    24 
    25         if debug and i % 1000 == 0:
    26             corrected = tf.equal(tf.argmax(labels, 1), tf.argmax(outputs, 1))
    27             accuracy = tf.reduce_mean(tf.cast(corrected, tf.float32))
    28             accuracyValue = sess.run(accuracy, feed_dict={inputs: batch[0], labels: batch[1]})
    29             print(i, ' train set accuracy:', accuracyValue)
    30 
    31     # 测试
    32     corrected = tf.equal(tf.argmax(labels, 1), tf.argmax(outputs, 1))
    33     accuracy = tf.reduce_mean(tf.cast(corrected, tf.float32))
    34     accuracyValue = sess.run(accuracy, feed_dict={inputs: mnist.test.images, labels: mnist.test.labels})
    35     print("accuracy on test set:", accuracyValue)
    36 
    37     sess.close()

    3、训练结果

      上述模型的最终输出为:

    由打印日志可以看出,前期收敛速度很快,后期开始波动。最后该模型在训练集上的正确率大概为90%,测试集上也差不多。精度还是比较低的,说明单层的神经网络在处理图片数据上存在着很大的缺陷,并不是一个很好的选择。

    本文地址:https://www.cnblogs.com/laishenghao/p/9576806.html

  • 相关阅读:
    AngularJS中实现无限级联动菜单
    理解AngularJS生命周期:利用ng-repeat动态解析自定义directive
    denounce函数:Javascript中如何应对高频触发事件
    Javascript中的循环变量声明,到底应该放在哪儿?
    优雅的数组降维——Javascript中apply方法的妙用
    如何利⽤360Quake挖掘某授权⼚商边缘站点漏洞
    Java课程设计--网络聊天室
    DS博客作业08--课程总结
    DS博客作业07--查找
    DS博客作业06--图
  • 原文地址:https://www.cnblogs.com/laishenghao/p/9576806.html
Copyright © 2011-2022 走看看