zoukankan      html  css  js  c++  java
  • 记录MNIST实现与理解

    之前半个月的时间几乎都在看理论书籍,最近两天开始撸代码,一个跟Hello World同级别的教程例子就出来了,那就是MNIST。实现代码应该很多地方都有:

     1 #!/usr/bin/env python
     2 # -*- coding: utf-8 -*-
     3 
     4 # @Author  : mario
     5 # @File    : mnist_main.py
     6 # @Project : base
     7 # @Time    : 2018-12-18 22:56:38
     8 # @Desc    : File is ...
     9 
    10 
    11 import tensorflow as tf
    12 from tensorflow.examples.tutorials.mnist import input_data
    13 
    14 mnist = input_data.read_data_sets("data/", one_hot=True)
    15 
    16 x = tf.placeholder(tf.float32, [None, 784], "image")
    17 W = tf.Variable(tf.zeros([784, 10]), name="weight")
    18 b = tf.Variable(tf.zeros([10]), name="bias")
    19 
    20 y = tf.nn.softmax(tf.matmul(x, W) + b)
    21 
    22 y_ = tf.placeholder(tf.float32, [None, 10])
    23 
    24 cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
    25 
    26 train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
    27 
    28 init = tf.global_variables_initializer()
    29 
    30 sess = tf.Session()
    31 
    32 sess.run(init)
    33 
    34 for _ in range(1000):
    35     batch_xs, batch_ys = mnist.train.next_batch(100)
    36     sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
    37 
    38 correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    39 
    40 correct_rate = tf.reduce_mean(tf.cast(correct_prediction, "float"))
    41 
    42 print(sess.run(correct_rate, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

    python版本:Python 3.6.5 (v3.6.5:f59c0932b4, Mar 28 2018, 05:52:31) ;tensor flow版本:1.12.0

    11行,12行:导入tensorflow模块和读取数据的模块

    14行:读取当前目录下data目录下的数据,在data目录下应该是下载好的数据文件(train-images-idx3-ubyte.gz、train-labels-idx1-ubyte.gz、10k-images-idx3-ubyte.gz、t10k-labels-idx1-ubyte.gz)

    16行,17行,18行:为了构建模型方程设置的参数,模型方程是:y = xW+b,其中,x是自变量也就是输入参数,W是权重值,b是偏置量,所以在定义时,为一个占位符需要后期输入;W和b设置为变量参数,因为会随着训练而改变。

    20行:构建整个算法模型关于softmax函数可以查一查,简单来说就是一种结果转换。y是每次训练的结果。

    22行:y_是测试数据的对应标签。设置为占位符是因为我们需要输入标准的测试数据的标签。

    24行:计算交叉熵

    26行:实现梯度下降,学习率为0.01,学习率大小直接影响成功率和训练时间

    28行:初始化

    30行,32行:使用session提交执行图

    34行:设置训练次数1000次

    35行:读取训练数据,每次100个

    36行:开始运行训练模型,其中feed_dict是为占位变量设置值

    38行,40行:比较标准标签结果和测试结果,计算成功率

    42行:运行计算成功率,将带测试的数据赋值给占位符

    整个过程是不很复杂,其中一些算法的实现原理资料上都随便找得到。不过在这其中遇到过两个异常:

    1:没有导入from tensorflow.examples.tutorials.mnist import input_data,而是将tensorflow.examples.tutorials.mnist 当作了14行的mnist使用,在执行到35行时抛出异常,异常为:AttributeError: module 'tensorflow.examples.tutorials.mnist.mnist' has no attribute 'train',一开始一直不知道原因,后来发给朋友,她告诉我没有导入数据集,突然才意识到真的忘记导入数据集了。

    2:异常信息:
    InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'Placeholder' with dtype float and shape [?,10]
         [[node Placeholder (defined at /Users/mario/CodeRepository/PycharmProjects/base/cn/mario/tensorflow/mnist/mnist_main.py:22)  = Placeholder[dtype=DT_FLOAT, shape=[?,10], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

    网上貌似也有一些人遇到了这个问题,开始也是很郁闷的,错误信息是没有给占位符赋值,placeholder需要一个类型为float形状为[?,10],我尝试了改变几种类型,但是如果类型不一致的话,错误信息会很直接的说明,它需要的是类型A,而你给了类型B,更改shape也是如此,所以开始感觉应该是不是类型不对dtype和shape的匹配问题,而是占位符就根本没有赋值,于是我检查了两个占位符,发现是我在36行将y_误写为了y。

    可能很多人写在这个例子的时候没有遇到什么问题,但我觉得遇到问题也不是坏事,遇到问题,解决问题,能理解的更多一些。

  • 相关阅读:
    MYSQL中排序
    编写一个 SQL 查询,获取 Employee 表中第二高的薪水(Salary)
    job1
    python中对于数组的操作
    python中将字符串转为字典类型
    python中的周几
    python 链接redis 获取对应的值
    jenkins 设置定时任务规则
    如何安全close go 的channel
    [转]
  • 原文地址:https://www.cnblogs.com/ben-mario/p/10142417.html
Copyright © 2011-2022 走看看