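# Environment setup -- run these commands in a shell, not inside Python:
#   pip install numpy -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com
#   pip install tensorflow-gpu==1.15.0 -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com
#   pip install opencv-python -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com
# NOTE: display_images() also imports matplotlib, which the commands above do not install.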
import time
import tensorflow as tf
import cv2 as cv
import numpy as np
def generate_image(a, rnd_size=100):
    """Draw ``a`` as white text on a black 28x28 canvas, then blank random pixels.

    Args:
        a: Value to render; it is converted with ``str`` before drawing.
        rnd_size: Number of randomly chosen pixels that are forced to 0 after
            the text has been drawn (the same position may be picked twice).

    Returns:
        A ``(canvas, features)`` tuple: the 28x28 ``uint8`` image, and the
        same pixels reshaped to ``(1, 784)`` and divided by 255.
    """
    canvas = np.zeros([28, 28], dtype=np.uint8)
    cv.putText(canvas, str(a), (7, 21), cv.FONT_HERSHEY_PLAIN, 1.3, 255, 2, 8)

    # Punch holes into the glyph; the row is drawn before the column each pass.
    for _ in range(rnd_size):
        r = np.random.randint(0, 28)
        c = np.random.randint(0, 28)
        canvas[r, c] = 0

    flat = canvas.reshape(1, 784)
    return canvas, flat / 255
def display_images(images):
    """Show the given images in a five-column grid using matplotlib.

    The grid always has at least two rows, so up to ten images are laid out
    2x5; further rows are added when more images are passed, because
    ``plt.subplot`` rejects an index larger than rows * columns.

    Args:
        images: Sequence of array-likes accepted by ``plt.imshow``.
    """
    # Imported lazily so matplotlib is only required when images are shown.
    import matplotlib.pyplot as plt

    columns = 5
    # Ceiling division; never fewer than two rows so small inputs keep the
    # 2x5 layout.
    rows = max(2, -(-len(images) // columns))
    for position, image in enumerate(images, start=1):
        plt.subplot(rows, columns, position)
        plt.imshow(image)
    plt.show()
def load_data(sess, rnd_size=100, should_display_images=False):
    """Generate one noisy sample for each digit 0-9 plus one-hot labels.

    Args:
        sess: Unused. Kept in the signature so callers that pass a session
            keep working.
        rnd_size: Forwarded to ``generate_image`` as the number of pixels to
            blank out in each image.
        should_display_images: When truthy, the ten generated images are
            shown through ``display_images``.

    Returns:
        A ``(x_features, y)`` tuple. ``x_features`` has shape ``(10, 784)``
        with row ``i`` holding the flattened image of digit ``i``; ``y`` is
        the matching 10x10 ``float32`` one-hot label matrix.
    """
    # Digits are generated in ascending order so row i belongs to digit i.
    samples = [generate_image(digit, rnd_size) for digit in range(10)]

    if should_display_images:
        display_images([image for image, _ in samples])

    x_features = np.array([features for _, features in samples])
    x_features = np.reshape(x_features, (-1, 784))

    # One-hot labels for 0..9 are simply the identity matrix. Building them
    # with NumPy instead of sess.run(tf.one_hot(...)) avoids adding a new op
    # to the TensorFlow graph on every call (this runs once per train step).
    y = np.eye(10, dtype=np.float32)
    return x_features, y
def build_network(nhidden, classes_count):
    """Assemble a one-hidden-layer sigmoid network together with its train op.

    Args:
        nhidden: Number of units in the hidden layer.
        classes_count: Number of output units (width of the label placeholder).

    Returns:
        ``(x, y, out_result, loss, step)``: the input placeholder (784
        features per row), the label placeholder, the sigmoid output tensor,
        the summed squared-error loss, and the gradient-descent training op
        (learning rate 0.1). A scalar summary named "loss" is registered too.
    """
    inputs = tf.placeholder(tf.float32, shape=[None, 784], name='x')
    targets = tf.placeholder(tf.float32, shape=[None, classes_count], name='y')

    # Hidden layer: affine transform followed by a sigmoid.
    hidden_weights = tf.Variable(tf.random_normal([784, nhidden]))
    hidden_bias = tf.Variable(tf.random_normal([1, nhidden]))
    hidden_activation = tf.sigmoid(
        tf.add(tf.matmul(inputs, hidden_weights), hidden_bias))

    # Output layer: same structure, one unit per class.
    output_weights = tf.Variable(tf.random_normal([nhidden, classes_count]))
    output_bias = tf.Variable(tf.random_normal([1, classes_count]))
    predictions = tf.sigmoid(
        tf.add(tf.matmul(hidden_activation, output_weights), output_bias))

    # Sum of squared differences between predictions and labels.
    loss = tf.reduce_sum(tf.square(tf.subtract(predictions, targets)))
    step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
    tf.summary.scalar("loss", loss)
    return inputs, targets, predictions, loss, step
def do_train():
    """Train the network on freshly generated digit images and report progress.

    Runs 800 gradient steps, each on a new noisy batch from ``load_data``.
    Every 50 steps the loss summary is written to a ``logs-<timestamp>``
    directory and the loss plus the number of correctly classified samples
    is printed. Afterwards one noisier batch (``rnd_size=150``) is generated,
    displayed and classified, and the predicted digits are printed.
    """
    x, y, y_, loss, step = build_network(10, 10)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        summary_merged = tf.summary.merge_all()
        writer = tf.summary.FileWriter('logs-' + str(time.time()), sess.graph)
        try:
            for i in range(800):
                x_features, y_labels = load_data(sess)
                feed = {x: x_features, y: y_labels}
                sess.run(step, feed_dict=feed)
                if (i + 1) % 50 == 0:
                    cur_loss, summary_, pred_ys = sess.run(
                        [loss, summary_merged, y_], feed_dict=feed)
                    writer.add_summary(summary_, i)
                    # Rows are samples and columns are classes, so the class
                    # of each sample is the argmax along axis 1 (axis 0 would
                    # pick a sample per class instead). NumPy is used so no
                    # new graph ops are created inside the training loop.
                    predicted = np.argmax(pred_ys, axis=1)
                    expected = np.argmax(y_labels, axis=1)
                    # float() keeps the printed count in the "4.0" format.
                    r = float(np.sum(predicted == expected))
                    print(i + 1, ': loss: ', cur_loss, '正确个数:', r)
            print('*************************')
            x_features, y_labels = load_data(sess, 150, should_display_images=True)
            pred_ys = sess.run(y_, feed_dict={x: x_features})
            r = np.argmax(pred_ys, axis=1)
            print('图片识别结果:', r)
        finally:
            # Flush and close the event file even if training raised.
            writer.close()
# Run the training demo only when this file is executed as a script.
if __name__ == "__main__":
    do_train()
# Example console output from one training run:
# 50 : loss: 7.3588676 正确个数: 4.0
# 100 : loss: 6.6502814 正确个数: 5.0
# 150 : loss: 5.26784 正确个数: 7.0
# 200 : loss: 4.0591483 正确个数: 9.0
# 250 : loss: 3.4379258 正确个数: 8.0
# 300 : loss: 3.114149 正确个数: 8.0
# 350 : loss: 2.0274947 正确个数: 9.0
# 400 : loss: 1.4823446 正确个数: 10.0
# 450 : loss: 1.4051719 正确个数: 10.0
# 500 : loss: 0.91150457 正确个数: 10.0
# 550 : loss: 0.7835213 正确个数: 10.0
# 600 : loss: 0.72512466 正确个数: 10.0
# 650 : loss: 0.56525075 正确个数: 10.0
# 700 : loss: 0.4699742 正确个数: 10.0
# 750 : loss: 0.45453963 正确个数: 10.0
# 800 : loss: 0.45089394 正确个数: 10.0
# *************************
# 图片识别结果: [0 1 2 3 4 5 6 7 8 9]