zoukankan      html  css  js  c++  java
  • tensorflow学习笔记————分类MNIST数据集

    在使用tensorflow分类MNIST数据集中,最容易遇到的问题是下载MNIST样本的问题。

    一般是通过使用tensorflow内置的函数进行下载和加载,

    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

    但是我使用时遇到了“urllib.error.URLError: <urlopen error [Errno 99] Cannot assign requested address>”错误,查了一下也没什么好的解决方案,最后就自己去手动下载了。在python文件同目录下建立MNIST_data,进入目录后通过wget来下载

    wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
    wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
    wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
    wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz

    最后运行我们的程序

     1 import tensorflow as tf
     2 from tensorflow.examples.tutorials.mnist import input_data
     3 
     4 #通过tensorflow的库来载入训练的样本
     5 mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
     6 
     7 #每个批次的大小
     8 batch_size = 100
     9 
    10 #计算有多少批次
    11 n_batch = mnist.train.num_examples // batch_size
    12 
    13 #定义两个placeholder,x是图片样本,y是输出的结果
    14 x = tf.placeholder(tf.float32, [None,784])
    15 y = tf.placeholder(tf.float32, [None,10])
    16 
    17 #创建一个简单的神经网络
    18 W = tf.Variable(tf.zeros([784,10]))
    19 b = tf.Variable(tf.zeros([10]))
    20 prediction = tf.nn.softmax(tf.matmul(x,W)+b)
    21 
    22 #二次代价函数
    23 loss = tf.reduce_mean(tf.square(y - prediction))
    24 
    25 #使用梯度下降法
    26 train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
    27 
    28 #初始化变量
    29 init = tf.global_variables_initializer()
    30 
    31 #结果存放在一个布尔类型列表中, tf.argmax返回一维张量中最大的值所在的位置,就是返回识别出来最可能的结果
    32 correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(prediction,1))
    33 
    34 #求准确率,tf.case()把bool转化为float
    35 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    36 
    37 with tf.Session() as sess:
    38     sess.run(init)
    39     for epoch in range(21):
    40         for batch in range(n_batch):
    41             batch_xs,batch_ys = mnist.train.next_batch(batch_size)
    42             sess.run(train_step, feed_dict={x:batch_xs, y:batch_ys})
    43     
    44         acc = sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels})
    45         print("Iter " + str(epoch) + ", Testing Accuracy" + str(acc))
    46     

  • 相关阅读:
    h5实现 微信的授权登录
    js判断浏览器的环境(pc端,移动端,还是微信浏览器)
    动态判断时间插件显示到年月日时分秒
    H5发起微信支付
    Vue项目结合vux使用
    Swift学习笔记一:常量和变量
    iOS开发之解决系统数字键盘无文字时delete键无法监听的技巧
    Swift3.0之获取设备识别号deviceNo和保存账户AccountId
    Swift3.0之自定义debug阶段控制台打印
    Xcode之command+/快捷键添加注释不起作用
  • 原文地址:https://www.cnblogs.com/QKSword/p/8723677.html
Copyright © 2011-2022 走看看