"""Softmax regression (784 -> 10) on MNIST using the TensorFlow 1.x graph API.

Builds a single fully-connected layer, trains it with mini-batch gradient
descent on a softmax cross-entropy loss, and prints test-set accuracy after
every epoch.
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Load the dataset from the "MNIST_data" directory with one-hot encoded labels.
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

# Mini-batch size.
batch_size = 64
# Number of mini-batches per epoch (floor division: a remainder smaller than
# batch_size is not counted as a batch).
n_batch = mnist.train.num_examples // batch_size
# Number of epochs; one epoch is one pass over the training data.
n_epochs = 21
# Step size for plain gradient descent.
learning_rate = 0.3

# Placeholders: 784 input features per example, 10 one-hot label columns.
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])

# A simple single-layer network: 784 inputs -> 10 outputs.
W = tf.Variable(tf.truncated_normal([784, 10], stddev=0.1))
b = tf.Variable(tf.zeros([10]) + 0.1)
# Unscaled scores (logits). The loss below must receive these, not the
# softmax output.
logits = tf.matmul(x, W) + b
prediction = tf.nn.softmax(logits)

# Cross-entropy loss. tf.losses.softmax_cross_entropy applies softmax to its
# `logits` argument internally. Passing `prediction` here would apply softmax
# twice and squash the gradients.
# (A quadratic-cost alternative would be
#  tf.losses.mean_squared_error(y, prediction).)
loss = tf.losses.softmax_cross_entropy(y, logits)
# Train with plain gradient descent.
train = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

# Per-example correctness as a boolean vector (argmax of label vs. argmax of
# prediction).
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
# Accuracy = fraction of correct predictions.
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

with tf.Session() as sess:
    # Initialise W and b.
    sess.run(tf.global_variables_initializer())
    for epoch in range(n_epochs):
        for _ in range(n_batch):
            # Fetch the next mini-batch of images and labels.
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train, feed_dict={x: batch_xs, y: batch_ys})
        # Evaluate on the full test set once per epoch.
        acc = sess.run(
            accuracy,
            feed_dict={x: mnist.test.images, y: mnist.test.labels},
        )
        print("Iter " + str(epoch) + ",Testing Accuracy " + str(acc))
# Example console output from one run (exact accuracies vary between runs):
# Extracting MNIST_data\train-images-idx3-ubyte.gz
# Extracting MNIST_data\train-labels-idx1-ubyte.gz
# Extracting MNIST_data\t10k-images-idx3-ubyte.gz
# Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
# Iter 0,Testing Accuracy 0.7473
# Iter 1,Testing Accuracy 0.8413
# Iter 2,Testing Accuracy 0.9066
# Iter 3,Testing Accuracy 0.9113
# Iter 4,Testing Accuracy 0.9143
# Iter 5,Testing Accuracy 0.9168
# Iter 6,Testing Accuracy 0.9199
# Iter 7,Testing Accuracy 0.9201
# Iter 8,Testing Accuracy 0.9202
# Iter 9,Testing Accuracy 0.9213
# Iter 10,Testing Accuracy 0.921
# Iter 11,Testing Accuracy 0.9205
# Iter 12,Testing Accuracy 0.9214
# Iter 13,Testing Accuracy 0.923
# Iter 14,Testing Accuracy 0.9237
# Iter 15,Testing Accuracy 0.9238
# Iter 16,Testing Accuracy 0.924
# Iter 17,Testing Accuracy 0.9231
# Iter 18,Testing Accuracy 0.9246
# Iter 19,Testing Accuracy 0.925
# Iter 20,Testing Accuracy 0.9253