  • TensorFlow getting-started test program: softmax regression on MNIST

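    # Requires TensorFlow 1.x: tensorflow.examples.tutorials and tf.placeholder
    # are not available in TensorFlow 2.x
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data

    # Download MNIST into MNIST_data/ if needed; one_hot=True turns each label into a 10-element vector
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

    # print(mnist.train.images.shape, mnist.train.labels.shape)
    # print(mnist.test.images.shape, mnist.test.labels.shape)
    # print(mnist.validation.images.shape, mnist.validation.labels.shape)

    # InteractiveSession installs itself as the default session, so .run()/.eval() below need no session
    sess = tf.InteractiveSession()

    # Each 28x28 image is flattened to 784 pixels; None allows any batch size
    x = tf.placeholder(tf.float32, [None, 784])

    # Softmax regression parameters, initialized to zero
    W = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))

    # Class scores (logits) and predicted probabilities y = softmax(xW + b)
    logits = tf.matmul(x, W) + b
    y = tf.nn.softmax(logits)

    # One-hot true labels
    y_ = tf.placeholder(tf.float32, [None, 10])
    # Mean cross-entropy computed directly from the logits; this is numerically safer
    # than -reduce_sum(y_ * tf.log(y)), which hits log(0) if a probability underflows to 0
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))

    # Plain gradient descent with learning rate 0.5
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
    tf.global_variables_initializer().run()

    # 1000 training steps, each on the next mini-batch of 100 images
    for i in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        train_step.run({x: batch_xs, y_: batch_ys})

    # A prediction is correct when the most probable class equals the labeled class
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # Accuracy on the 10,000 test images
    print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))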
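The model and loss in the program compute, for every image, a softmax over the ten class scores `xW + b` followed by the cross-entropy against the one-hot label, averaged over the batch. As an illustration only, the plain-Python sketch below does that arithmetic for a single toy example with three classes instead of ten; the helper names `softmax` and `cross_entropy` are hypothetical, not TensorFlow functions. It subtracts the largest score before calling `exp`, a common trick that prevents overflow, and computes each log-probability as score minus log-sum-exp, so it never takes `log(0)`.

```python
import math

def softmax(scores):
    """Turn a list of raw scores (logits) into probabilities that sum to 1."""
    m = max(scores)  # subtract the max so math.exp never overflows
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(scores, one_hot):
    """-sum_k y_k * log(p_k), with log(p_k) computed as s_k - logsumexp(s)."""
    m = max(scores)
    log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
    return -sum(y * (s - log_sum_exp) for s, y in zip(scores, one_hot))

# Hypothetical toy example: 3 classes instead of MNIST's 10, true class is index 2
scores = [1.0, 2.0, 3.0]
label = [0.0, 0.0, 1.0]
print([round(p, 3) for p in softmax(scores)])
print(round(cross_entropy(scores, label), 4))

# A huge score gap still gives a finite loss
print(cross_entropy([1000.0, 0.0], [0.0, 1.0]))  # → 1000.0
```

By contrast, taking `math.log` of the second softmax probability in that last case would fail, because `exp(-1000)` underflows to exactly 0.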

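`tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)` lets TensorFlow differentiate the loss and, on each batch, update `W` and `b` by subtracting 0.5 times the gradient. For softmax followed by cross-entropy, the derivative of the loss with respect to each score is simply `p - y` (predicted probability minus one-hot label), so the gradients are the batch averages of `x * (p - y)` for `W` and of `p - y` for `b`. The plain-Python sketch below writes those gradients out by hand; the toy two-feature, two-class data and the helper names (`scores_for`, `mean_loss`, `sgd_step`) are hypothetical, and it uses the full batch of four examples rather than mini-batches of 100.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def scores_for(W, b, x):
    """Logits xW + b for one example; W is a list of rows (features x classes)."""
    return [sum(x[i] * W[i][k] for i in range(len(x))) + b[k] for k in range(len(b))]

def mean_loss(W, b, xs, ys):
    """Mean cross-entropy over the batch."""
    total = 0.0
    for x, y in zip(xs, ys):
        p = softmax(scores_for(W, b, x))
        total -= sum(y[k] * math.log(p[k]) for k in range(len(y)) if y[k] > 0)
    return total / len(xs)

def sgd_step(W, b, xs, ys, lr):
    """One gradient-descent step on mean_loss; returns the new (W, b)."""
    n = len(xs)
    grad_W = [[0.0] * len(b) for _ in W]
    grad_b = [0.0] * len(b)
    for x, y in zip(xs, ys):
        p = softmax(scores_for(W, b, x))
        for k in range(len(b)):
            diff = (p[k] - y[k]) / n  # d(loss)/d(score_k), averaged over the batch
            grad_b[k] += diff
            for i in range(len(x)):
                grad_W[i][k] += x[i] * diff
    new_W = [[W[i][k] - lr * grad_W[i][k] for k in range(len(b))] for i in range(len(W))]
    new_b = [b[k] - lr * grad_b[k] for k in range(len(b))]
    return new_W, new_b

# Hypothetical toy data: 2 features, 2 classes, one-hot labels
xs = [[1.0, 0.0], [0.0, 1.0], [0.9, 0.1], [0.2, 0.8]]
ys = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]

W = [[0.0, 0.0], [0.0, 0.0]]  # zero initialization, as in the TensorFlow program
b = [0.0, 0.0]
print("loss before:", round(mean_loss(W, b, xs, ys), 4))
for step in range(100):
    W, b = sgd_step(W, b, xs, ys, lr=0.5)  # same learning rate as the program
print("loss after:", round(mean_loss(W, b, xs, ys), 4))
```

With all-zero parameters every class gets probability 1/2, so the starting loss is log 2; each step then lowers it. TensorFlow computes the same `p - y` gradients automatically through `minimize`.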

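The evaluation step at the end of the program measures accuracy: arg-max along axis 1 picks the most probable class in each row of `y` and the labeled class in each row of `y_`, `tf.equal` marks the rows where they agree, and casting those booleans to 0.0/1.0 and averaging gives the fraction classified correctly. The plain-Python sketch below reproduces that calculation on four made-up probability rows; `argmax` and `accuracy` here are hypothetical helpers, and this sketch breaks ties toward the first index, which TensorFlow does not guarantee.

```python
def argmax(row):
    """Index of the largest entry (the first one wins on ties in this sketch)."""
    return max(range(len(row)), key=lambda i: row[i])

def accuracy(probs, one_hot_labels):
    """Fraction of rows whose most probable class matches the labeled class."""
    correct = [1.0 if argmax(p) == argmax(y) else 0.0
               for p, y in zip(probs, one_hot_labels)]
    return sum(correct) / len(correct)

# Made-up predictions for 4 examples over 3 classes
probs = [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6], [0.5, 0.4, 0.1], [0.2, 0.5, 0.3]]
labels = [[1, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 0]]
print(accuracy(probs, labels))  # → 0.75 (only the third row is wrong)
```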
  • Original post: https://www.cnblogs.com/acm-jing/p/8490765.html