"""Train a one-hidden-layer fully connected network on MNIST (TensorFlow 1.x API).

Steps:
1. Load MNIST with one-hot labels.
2. Build a 784 -> 50 (ReLU) -> 10 network and train it with plain SGD on
   softmax cross-entropy.
3. Every 1000 steps print the accuracy on the mini-batch just trained on,
   then print the accuracy on the full test set.
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

numClasses = 10               # output classes (digits 0-9)
inputsize = 784               # 28 * 28 flattened pixels per image
numHiddenUnits = 50           # width of the single hidden layer
trainningIterations = 50000   # total training steps
batchSize = 64                # examples per SGD step

# 1. Dataset: read from (or downloaded into) data/, labels one-hot encoded.
mnist = input_data.read_data_sets('data/', one_hot=True)

############################################################
# 2. Train
X = tf.placeholder(tf.float32, shape=[None, inputsize])
y = tf.placeholder(tf.float32, shape=[None, numClasses])

# 2.1 Initial parameters
# hidden = relu(X * W1 + B1)
W1 = tf.Variable(tf.truncated_normal([inputsize, numHiddenUnits], stddev=0.1))
# The shape belongs to tf.constant. Passing it as tf.Variable's second
# positional argument (``trainable``) would create one scalar bias shared by
# every unit instead of one bias per unit.
B1 = tf.Variable(tf.constant(0.1, shape=[numHiddenUnits]))
# logits = hidden * W2 + B2
W2 = tf.Variable(tf.truncated_normal([numHiddenUnits, numClasses], stddev=0.1))
B2 = tf.Variable(tf.constant(0.1, shape=[numClasses]))

# Layers
hiddenLayerOutput = tf.nn.relu(tf.matmul(X, W1) + B1)
# Raw (unscaled) logits: softmax_cross_entropy_with_logits applies softmax
# internally. A ReLU here would clamp negative class scores to 0 and zero out
# their gradients.
finalOutput = tf.matmul(hiddenLayerOutput, W2) + B2

# 2.2 Training setup
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=finalOutput))
opt = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(loss)

# Fraction of rows whose highest-scoring class matches the one-hot label.
correct_prediction = tf.equal(tf.argmax(finalOutput, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# 2.3 Run training; the with-block closes the session when finished.
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    for i in range(trainningIterations):
        batchInput, batchLabels = mnist.train.next_batch(batchSize)
        sess.run(opt, feed_dict={X: batchInput, y: batchLabels})
        if i % 1000 == 0:
            # Accuracy on the batch just trained on: a cheap, noisy progress signal.
            train_accuracy = sess.run(
                accuracy, feed_dict={X: batchInput, y: batchLabels})
            print("step %d, tarinning accuracy %g" % (i, train_accuracy))

    # 2.4 Evaluate on the entire test set rather than a single 64-image batch.
    testAccuracy = sess.run(
        accuracy, feed_dict={X: mnist.test.images, y: mnist.test.labels})
    print("test accuracy %g" % (testAccuracy))
输出结果:
step 0, tarinning accuracy 0.171875
step 1000, tarinning accuracy 0.84375
step 2000, tarinning accuracy 0.953125
step 3000, tarinning accuracy 0.84375
step 4000, tarinning accuracy 0.953125
step 5000, tarinning accuracy 1
step 6000, tarinning accuracy 0.984375
step 7000, tarinning accuracy 1
step 8000, tarinning accuracy 0.984375
step 9000, tarinning accuracy 1
step 10000, tarinning accuracy 1
step 11000, tarinning accuracy 0.96875
step 12000, tarinning accuracy 1
step 13000, tarinning accuracy 0.96875
step 14000, tarinning accuracy 1
step 15000, tarinning accuracy 0.984375
step 16000, tarinning accuracy 0.953125
step 17000, tarinning accuracy 1
step 18000, tarinning accuracy 1
step 19000, tarinning accuracy 1
step 20000, tarinning accuracy 1
step 21000, tarinning accuracy 1
step 22000, tarinning accuracy 1
step 23000, tarinning accuracy 1
step 24000, tarinning accuracy 1
step 25000, tarinning accuracy 1
step 26000, tarinning accuracy 1
step 27000, tarinning accuracy 1
step 28000, tarinning accuracy 1
step 29000, tarinning accuracy 1
step 30000, tarinning accuracy 1
step 31000, tarinning accuracy 1
step 32000, tarinning accuracy 1
step 33000, tarinning accuracy 1
step 34000, tarinning accuracy 1
step 35000, tarinning accuracy 1
step 36000, tarinning accuracy 1
step 37000, tarinning accuracy 1
step 38000, tarinning accuracy 1
step 39000, tarinning accuracy 1
step 40000, tarinning accuracy 0.984375
step 41000, tarinning accuracy 1
step 42000, tarinning accuracy 1
step 43000, tarinning accuracy 1
step 44000, tarinning accuracy 1
step 45000, tarinning accuracy 1
step 46000, tarinning accuracy 1
step 47000, tarinning accuracy 1
step 48000, tarinning accuracy 1
step 49000, tarinning accuracy 1
test accuracy 0.984375