MNIST手写数字识别——神经网络实现(TensorFlow框架)
这段代码复制即可运行,先贴出代码,有空时再补充注释和解析。这是大三时写的代码,特别适合新手入门;不过现在大家基本都用PyTorch了。
本机使用的TensorFlow版本是1.13.1(代码使用tf.placeholder、tf.Session等TF 1.x的API,在TF 2.x下无法直接运行),用CPU跑也挺快。之前用GPU训练了半小时,准确率能达到98%左右;下方结果是默认20个epoch的输出。
代码
# -*- coding:utf-8 -*-
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from matplotlib import pyplot
import matplotlib.pyplot as plt
import numpy as np
# Read MNIST from the "MNIST_data" directory with one-hot encoded labels.
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

# Seed both random number generators. numpy drives the mini-batch shuffling;
# the TensorFlow graph-level seed makes the random weight initialisation
# (tf.random_normal in xavier_init) repeatable too. It must be set before any
# random op is created, i.e. before the variables are built.
seed = 547
np.random.seed(seed)
tf.set_random_seed(seed)

epoch_time = 20  # number of training epochs per run
ALPHY = 0.5      # learning rate for plain gradient descent
batch_size = 10  # examples per SGD step
# Mini-batches per epoch when training on the full training set ...
n_batch_all = mnist.train.num_examples // batch_size
# ... and when training on the fixed 1000-example subset.
n_batch = 1000 // batch_size

# Graph inputs: 784-pixel flattened images and 10-way one-hot labels.
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
def xavier_init(size):
    """Return a random-normal initial-value tensor of shape ``size``.

    The standard deviation is ``1 / sqrt(size[0] / 2)``, i.e.
    ``sqrt(2 / fan_in)`` with ``fan_in = size[0]``.

    NOTE(review): despite the name, ``sqrt(2 / fan_in)`` is the He-style
    scaling rather than Glorot/Xavier (``sqrt(2 / (fan_in + fan_out))``).
    """
    return tf.random_normal(shape=size, stddev=1. / tf.sqrt(size[0] / 2.))
# Two-layer network: 784 inputs -> 30 sigmoid hidden units -> 10 output logits.
# Weights come from xavier_init, biases start at zero.
W1 =tf.Variable(xavier_init([784, 30]))
B1 = tf.Variable(tf.zeros([30]))
L1 = tf.nn.sigmoid(tf.matmul(x,W1) + B1)
# Output layer: one logit per class (10 columns, matching y).
W2 =tf.Variable(xavier_init([30, 10]))
B2 = tf.Variable(tf.zeros([10]))
logit_prediction = tf.matmul(L1,W2) + B2
# Per-class sigmoid scores; used for the accuracy metric below, while training
# works directly on the raw logits.
prediction = tf.nn.sigmoid(logit_prediction)
# MSE loss (alternative, disabled)
# loss = tf.reduce_mean(tf.square(y - prediction))
# Cross-entropy loss: independent sigmoid cross-entropy per class, from logits.
# NOTE(review): this loss is element-wise and is not averaged with
# tf.reduce_mean, so the update size depends on the batch size; confirm this
# scaling is intended together with ALPHY=0.5.
loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logit_prediction,labels=y)
# One plain gradient-descent step (learning rate ALPHY) per sess.run(train_setup).
train_setup = tf.train.GradientDescentOptimizer(ALPHY).minimize(loss)
# Initializer for all variables; each training function re-runs it first.
init = tf.global_variables_initializer()
# Accuracy: fraction of examples whose highest-scoring class matches the label.
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
def getBatch(inputs, size=10, n_features=784, n_classes=10):
    """Draw a random mini-batch from a stacked ``[features | labels]`` array.

    Args:
        inputs: 2-D numpy array with one example per row. The first
            ``n_features`` columns hold the input and the next ``n_classes``
            columns the one-hot label. NOTE: the array is shuffled in place.
        size: number of rows in the batch. The default of 10 matches the
            script's ``batch_size``; it was previously hard-coded.
        n_features: number of feature columns (784 for flattened MNIST).
        n_classes: number of label columns following the features.

    Returns:
        ``(batch_x, batch_y)``: the feature slice and the label slice of the
        first ``size`` rows after shuffling. Both slices come from the same
        rows, so features and labels stay aligned.
    """
    # Shuffling the rows as a whole keeps each image paired with its label.
    np.random.shuffle(inputs)
    batch = inputs[:size]
    batch_x = batch[:, :n_features]
    batch_y = batch[:, n_features:n_features + n_classes]
    return batch_x, batch_y
def draw(train, text):
    """Plot per-epoch test accuracy of both training runs and save the figure.

    Args:
        train: per-epoch test accuracies of the run trained on the
            1000-example subset (plotted as ``train_1000``).
        text: per-epoch test accuracies of the run trained on the full
            training set (plotted as ``train_all``). Must have the same
            length as ``train``.

    The figure is written to ``accuracy.jpg`` at 900 dpi.
    """
    # Derive the epoch axis from the data rather than the global epoch_time,
    # so the ticks always match the curves being drawn. The name ``positions``
    # also avoids shadowing the module-level placeholder ``x``.
    positions = range(len(train))
    names = [str(i) for i in positions]
    plt.plot(positions, train, marker='o', mec='r', mfc='w', label='train_1000')
    plt.plot(positions, text, marker='*', ms=10, label='train_all')
    plt.legend()
    plt.xticks(positions, names, rotation=1)
    plt.margins(0)
    plt.subplots_adjust(bottom=0.10)
    plt.xlabel('epoch')
    plt.ylabel("accuracy")
    plt.yticks([0, 0.5, 1])
    plt.savefig('accuracy.jpg', dpi=900)
def train_1000():
    """Train on a fixed 1000-example subset of MNIST and track test accuracy.

    Re-initialises all variables and draws 1000 training examples once. Then,
    for each of ``epoch_time`` epochs, it runs ``n_batch`` SGD steps on random
    mini-batches sampled from that subset via ``getBatch``, and evaluates
    accuracy on the full test set.

    Uses the module-level ``sess``, ``mnist`` and graph tensors.

    Returns:
        numpy float32 array of length ``epoch_time`` holding the test accuracy
        after each epoch.
    """
    sess.run(init)
    X_mb, Y_mb = mnist.train.next_batch(1000)
    # Stack images and labels column-wise so getBatch can shuffle them as a
    # unit. This is done in numpy: the previous tf.concat/tf.zeros approach
    # added new nodes to the graph on every call only to evaluate them once.
    inputs = np.concatenate([X_mb, Y_mb.astype(np.float32)], axis=1)
    train = np.zeros(epoch_time, dtype=np.float32)
    for epoch in range(epoch_time):
        for _ in range(n_batch):
            fina_x, fina_y = getBatch(inputs)
            sess.run(train_setup, feed_dict={x: fina_x, y: fina_y})
        acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
        train[epoch] = acc
        print("Iter" + str(epoch) + ", Testing Accuracy=" + str(acc))
    return train
def train_all():
    """Train on the full MNIST training set and track test accuracy.

    Re-initialises all variables. Then, for each of ``epoch_time`` epochs, it
    runs ``n_batch_all`` SGD steps on ``batch_size``-sized mini-batches from
    ``mnist.train`` and evaluates accuracy on the full test set.

    Uses the module-level ``sess``, ``mnist`` and graph tensors.

    Returns:
        numpy float32 array of length ``epoch_time`` holding the test accuracy
        after each epoch.
    """
    sess.run(init)
    # A plain numpy buffer; no need to build and evaluate a tf.zeros op.
    text = np.zeros(epoch_time, dtype=np.float32)
    for epoch in range(epoch_time):
        for _ in range(n_batch_all):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_setup, feed_dict={x: batch_xs, y: batch_ys})
        acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
        text[epoch] = acc
        print("Iter" + str(epoch) + ", Testing Accuracy=" + str(acc))
    # Return the curve; without this, draw() would receive None and fail.
    return text
# Entry point: both training runs share a single session (the helper
# functions read the module-level name ``sess``), then both accuracy curves
# are plotted side by side.
with tf.Session() as sess:
    acc_subset = train_1000()
    acc_full = train_all()
    draw(acc_subset, acc_full)
结果(前20行为train_1000的输出——仅用1000个训练样本训练后的测试准确率;后20行为train_all的输出——使用全部训练集训练后的测试准确率)
Iter0, Testing Accuracy=0.0982
Iter1, Testing Accuracy=0.2913
Iter2, Testing Accuracy=0.2973
Iter3, Testing Accuracy=0.3493
Iter4, Testing Accuracy=0.4311
Iter5, Testing Accuracy=0.3789
Iter6, Testing Accuracy=0.49
Iter7, Testing Accuracy=0.4547
Iter8, Testing Accuracy=0.4079
Iter9, Testing Accuracy=0.4748
Iter10, Testing Accuracy=0.564
Iter11, Testing Accuracy=0.5026
Iter12, Testing Accuracy=0.6053
Iter13, Testing Accuracy=0.6379
Iter14, Testing Accuracy=0.5863
Iter15, Testing Accuracy=0.6443
Iter16, Testing Accuracy=0.6487
Iter17, Testing Accuracy=0.5809
Iter18, Testing Accuracy=0.6616
Iter19, Testing Accuracy=0.6465
Iter0, Testing Accuracy=0.7625
Iter1, Testing Accuracy=0.864
Iter2, Testing Accuracy=0.8596
Iter3, Testing Accuracy=0.8694
Iter4, Testing Accuracy=0.9028
Iter5, Testing Accuracy=0.9046
Iter6, Testing Accuracy=0.902
Iter7, Testing Accuracy=0.9021
Iter8, Testing Accuracy=0.8874
Iter9, Testing Accuracy=0.9192
Iter10, Testing Accuracy=0.9175
Iter11, Testing Accuracy=0.9226
Iter12, Testing Accuracy=0.9233
Iter13, Testing Accuracy=0.9156
Iter14, Testing Accuracy=0.93
Iter15, Testing Accuracy=0.9251
Iter16, Testing Accuracy=0.9232
Iter17, Testing Accuracy=0.9176
Iter18, Testing Accuracy=0.9287
Iter19, Testing Accuracy=0.9273