pip install numpy -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com
pip install tensorflow-gpu==1.15.0 -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com
pip install opencv-python -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com
pip install matplotlib -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com
import time

import cv2 as cv
import numpy as np
import tensorflow as tf

# Side length, in pixels, of each generated digit image.
IMAGE_SIDE = 28
# Number of pixels in a flattened image; this is the network's input width.
IMAGE_PIXELS = IMAGE_SIDE * IMAGE_SIDE
# Number of digit classes (0-9).
NUM_CLASSES = 10


def generate_image(a, rnd_size=100):
    """Render ``a`` as white text on a black 28x28 image and add pixel noise.

    Args:
        a: Value to draw; it is converted with ``str()`` before rendering.
        rnd_size: Number of random pixel positions forced to 0 (black).
            Positions are drawn with replacement, so fewer distinct pixels
            may actually change.

    Returns:
        ``(image, data)``: ``image`` is the uint8 array of shape (28, 28);
        ``data`` is the same pixels flattened to shape (1, 784) and divided
        by 255, giving floats in [0, 1].
    """
    image = np.zeros([IMAGE_SIDE, IMAGE_SIDE], dtype=np.uint8)
    # Hershey-plain font, scale 1.3, color 255, thickness 2, line type 8.
    cv.putText(image, str(a), (7, 21), cv.FONT_HERSHEY_PLAIN, 1.3, 255, 2, 8)
    # Blank out random pixels so every call yields a slightly different sample.
    rows = np.random.randint(0, IMAGE_SIDE, size=rnd_size)
    cols = np.random.randint(0, IMAGE_SIDE, size=rnd_size)
    image[rows, cols] = 0
    data = np.reshape(image, [1, IMAGE_PIXELS])
    return image, data / 255


def display_images(images, cols=5):
    """Show ``images`` with matplotlib in a grid that is ``cols`` wide.

    The number of rows is derived from ``len(images)``, so any number of
    images fits. matplotlib is imported lazily; it is needed only when this
    function is called.
    """
    import matplotlib.pyplot as plt

    rows = max(1, (len(images) + cols - 1) // cols)
    for index, image in enumerate(images):
        plt.subplot(rows, cols, index + 1)
        plt.imshow(image)
    plt.show()


def load_data(sess, rnd_size=100, should_display_images=False):
    """Generate one noisy sample of each digit 0-9 together with one-hot labels.

    Args:
        sess: Unused. It is kept so existing callers that pass a session still
            work. The labels are built with NumPy rather than by running a
            ``tf.one_hot`` op, because creating that op on every call kept
            adding nodes to the default graph.
        rnd_size: Noise pixels per image, forwarded to ``generate_image``.
        should_display_images: If truthy, show the generated images.

    Returns:
        ``(x_features, y)``: ``x_features`` has shape (10, 784); ``y`` is a
        float32 (10, 10) one-hot matrix whose row ``i`` labels digit ``i``.
    """
    samples = [generate_image(digit, rnd_size) for digit in range(NUM_CLASSES)]
    if should_display_images:
        display_images([image for image, _ in samples])
    x_features = np.reshape(
        np.array([data for _, data in samples]), (-1, IMAGE_PIXELS))
    # Identity matrix == one-hot encoding of the labels [0, 1, ..., 9].
    y = np.eye(NUM_CLASSES, dtype=np.float32)
    return x_features, y


def build_network(nhidden, classes_count, learning_rate=0.1):
    """Build a one-hidden-layer sigmoid network trained on summed squared error.

    Args:
        nhidden: Width of the hidden layer.
        classes_count: Number of output units (one per class).
        learning_rate: Step size for plain gradient descent.

    Returns:
        ``(x, y, out_result, loss, step)``: the input and label placeholders,
        the sigmoid output tensor, the scalar loss, and the training op.
        As a side effect, a "loss" scalar summary is added to the graph.
    """
    x = tf.placeholder(tf.float32, shape=[None, IMAGE_PIXELS], name='x')
    y = tf.placeholder(tf.float32, shape=[None, classes_count], name='y')

    W1 = tf.Variable(tf.random_normal([IMAGE_PIXELS, nhidden]))
    b1 = tf.Variable(tf.random_normal([1, nhidden]))
    hidden1_result = tf.sigmoid(tf.add(tf.matmul(x, W1), b1))

    W2 = tf.Variable(tf.random_normal([nhidden, classes_count]))
    b2 = tf.Variable(tf.random_normal([1, classes_count]))
    out_result = tf.sigmoid(tf.add(tf.matmul(hidden1_result, W2), b2))

    loss = tf.reduce_sum(tf.square(tf.subtract(out_result, y)))
    step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
    tf.summary.scalar("loss", loss)
    return x, y, out_result, loss, step


def _count_correct(pred_ys, y_labels):
    """Return, as a float, how many rows' predicted class equals the label class.

    Each row is one sample, so the class is the argmax over axis 1.
    """
    predicted = np.argmax(pred_ys, axis=1)
    expected = np.argmax(y_labels, axis=1)
    return float(np.sum(predicted == expected))


def do_train(steps=800, log_every=50):
    """Train the network on freshly generated digits and print progress.

    Every ``log_every`` steps, the loss and the number of correctly classified
    samples are printed, and a loss summary is written to a ``logs-<timestamp>``
    directory. Afterwards, a noisier batch is displayed and its predicted
    digits are printed.
    """
    x, y, y_, loss, step = build_network(10, NUM_CLASSES)
    # Build every op up front: ops created inside the loop would grow the graph.
    summary_merged = tf.summary.merge_all()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        writer = tf.summary.FileWriter('logs-'+str(time.time()), sess.graph)
        try:
            for i in range(steps):
                x_features, y_labels = load_data(sess)
                feed = {x: x_features, y: y_labels}
                sess.run(step, feed_dict=feed)
                if (i + 1) % log_every == 0:
                    cur_loss, summary_, pred_ys = sess.run(
                        [loss, summary_merged, y_], feed_dict=feed)
                    writer.add_summary(summary_, i)
                    r = _count_correct(pred_ys, y_labels)
                    print(i + 1, ': loss: ', cur_loss, '正确个数:', r)
            print('*************************')
            x_features, _ = load_data(sess, 150, should_display_images=True)
            pred_ys = sess.run(y_, feed_dict={x: x_features})
            # One predicted class per image (argmax across the class axis).
            r = np.argmax(pred_ys, axis=1)
            print('图片识别结果:', r)
        finally:
            writer.close()


if __name__ == '__main__':
    do_train()
50 : loss: 7.3588676 正确个数: 4.0
100 : loss: 6.6502814 正确个数: 5.0
150 : loss: 5.26784 正确个数: 7.0
200 : loss: 4.0591483 正确个数: 9.0
250 : loss: 3.4379258 正确个数: 8.0
300 : loss: 3.114149 正确个数: 8.0
350 : loss: 2.0274947 正确个数: 9.0
400 : loss: 1.4823446 正确个数: 10.0
450 : loss: 1.4051719 正确个数: 10.0
500 : loss: 0.91150457 正确个数: 10.0
550 : loss: 0.7835213 正确个数: 10.0
600 : loss: 0.72512466 正确个数: 10.0
650 : loss: 0.56525075 正确个数: 10.0
700 : loss: 0.4699742 正确个数: 10.0
750 : loss: 0.45453963 正确个数: 10.0
800 : loss: 0.45089394 正确个数: 10.0
*************************
图片识别结果: [0 1 2 3 4 5 6 7 8 9]