zoukankan      html  css  js  c++  java
  • tensorflow学习笔记————分类MNIST数据集

    在使用tensorflow分类MNIST数据集中,最容易遇到的问题是下载MNIST样本的问题。

    一般是通过使用tensorflow内置的函数进行下载和加载,

    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

    但是我使用时遇到了“urllib.error.URLError: <urlopen error [Errno 99] Cannot assign requested address>”错误,查了一下也没什么好的解决方案,最后就自己去手动下载了。在python文件同目录下建立MNIST_data,进入目录后通过wget来下载

    wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
    wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
    wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
    wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz

    最后运行我们的程序

     1 import tensorflow as tf
     2 from tensorflow.examples.tutorials.mnist import input_data
     3 
     4 #通过tensorflow的库来载入训练的样本
     5 mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
     6 
     7 #每个批次的大小
     8 batch_size = 100
     9 
    10 #计算有多少批次
    11 n_batch = mnist.train.num_examples // batch_size
    12 
    13 #定义两个placeholder,x是图片样本,y是输出的结果
    14 x = tf.placeholder(tf.float32, [None,784])
    15 y = tf.placeholder(tf.float32, [None,10])
    16 
    17 #创建一个简单的神经网络
    18 W = tf.Variable(tf.zeros([784,10]))
    19 b = tf.Variable(tf.zeros([10]))
    20 prediction = tf.nn.softmax(tf.matmul(x,W)+b)
    21 
    22 #二次代价函数
    23 loss = tf.reduce_mean(tf.square(y - prediction))
    24 
    25 #使用梯度下降法
    26 train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
    27 
    28 #初始化变量
    29 init = tf.global_variables_initializer()
    30 
    31 #结果存放在一个布尔类型列表中, tf.argmax返回一维张量中最大的值所在的位置,就是返回识别出来最可能的结果
    32 correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(prediction,1))
    33 
    34 #求准确率,tf.case()把bool转化为float
    35 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    36 
    37 with tf.Session() as sess:
    38     sess.run(init)
    39     for epoch in range(21):
    40         for batch in range(n_batch):
    41             batch_xs,batch_ys = mnist.train.next_batch(batch_size)
    42             sess.run(train_step, feed_dict={x:batch_xs, y:batch_ys})
    43     
    44         acc = sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels})
    45         print("Iter " + str(epoch) + ", Testing Accuracy" + str(acc))
    46     

  • 相关阅读:
    Unity打包ARCore项目失败,但是其他安卓项目成功
    关于Unity 图片队列存储以及出列导致内存溢出的解决方案
    unity 使用 outline 组件
    7Z解压工具的BUG
    Unity ILRuntime 调用方法一览
    Python 免费插件
    SQL经典面试题及答案
    PL/SQL Developer中文注释乱码的解决办法
    Tomcat并发优化和缓存优化
    在配置hibernate.cfg.xml时需指定使用数据库的方言:
  • 原文地址:https://www.cnblogs.com/QKSword/p/8723677.html
Copyright © 2011-2022 走看看