zoukankan      html  css  js  c++  java
  • 使用 TensorFlow 对单张图片进行 OCR(识别数字 0 到 9)

    pip install numpy -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com
    pip install tensorflow-gpu==1.15.0 -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com
    pip install opencv-python -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com
    import time
    
    import tensorflow as tf
    import cv2 as cv
    import numpy as np
    
    
    def generate_image(a, rnd_size=100):
        """Draw ``str(a)`` on a black 28x28 canvas and erase random pixels.

        Args:
            a: Value to render; it is converted with ``str`` before drawing.
            rnd_size: Number of random pixel positions that are reset to 0
                after drawing. Positions are drawn independently, so the same
                pixel may be picked more than once.

        Returns:
            A tuple ``(image, data)`` where ``image`` is the 28x28 ``uint8``
            canvas and ``data`` is the same canvas reshaped to ``(1, 784)``
            and divided by 255.
        """
        side = 28
        canvas = np.zeros([side, side], dtype=np.uint8)
        # White (255) text, font scale 1.3, thickness 2, line type 8,
        # with the text origin at (7, 21).
        cv.putText(canvas, str(a), (7, 21), cv.FONT_HERSHEY_PLAIN, 1.3, 255, 2, 8)

        # Noise: blank out randomly chosen pixels. The row is drawn before the
        # column on every iteration.
        for _ in range(rnd_size):
            noise_row = np.random.randint(0, side)
            noise_col = np.random.randint(0, side)
            canvas[noise_row, noise_col] = 0

        flattened = canvas.reshape([1, side * side])
        return canvas, flattened / 255
    
    
    def display_images(images):
        """Show the given images side by side in a matplotlib grid.

        The grid is five columns wide. It has two rows for up to ten images
        (the original fixed layout) and grows by one row for every further
        five images, so longer lists no longer fail with an out-of-range
        subplot index.

        Args:
            images: Sequence of arrays accepted by ``plt.imshow``.
        """
        # Imported lazily so matplotlib is only required when images are shown.
        import matplotlib.pyplot as plt

        columns = 5
        # Ceiling division without importing math; never fewer than two rows
        # so the layout for <= 10 images is unchanged.
        rows = max(2, -(-len(images) // columns))
        for position, picture in enumerate(images, start=1):
            plt.subplot(rows, columns, position)
            plt.imshow(picture)

        plt.show()
    
    
    def load_data(sess, rnd_size=100, should_display_images=False):
        """Generate one noisy sample for each digit 0..9 with one-hot labels.

        Args:
            sess: Unused. Kept only so existing callers that pass a TensorFlow
                session keep working; the labels are now built with NumPy.
            rnd_size: Number of random pixels erased in every generated image
                (forwarded to ``generate_image``).
            should_display_images: When exactly ``True``, the ten generated
                images are shown via ``display_images`` before returning.

        Returns:
            A tuple ``(x_features, y)``: ``x_features`` has shape ``(10, 784)``
            and ``y`` is the ``(10, 10)`` float32 one-hot encoding of the
            labels 0..9, in that order.
        """
        digits = list(range(10))
        # Digits are generated in ascending order, as before, so the sequence
        # of random draws is unchanged.
        samples = [generate_image(digit, rnd_size) for digit in digits]

        if should_display_images is True:
            display_images([picture for picture, _ in samples])

        x_features = np.reshape(np.array([flat for _, flat in samples]), (-1, 784))

        # Build the one-hot labels with NumPy. The previous
        # sess.run(tf.one_hot(...)) added a new op to the graph on every call,
        # and this function is called once per training step.
        y = np.eye(10, dtype=np.float32)[digits]

        return x_features, y
    
    
    def build_network(nhidden, classes_count):
        """Build a network with one sigmoid hidden layer and a sigmoid output.

        Args:
            nhidden: Number of units in the hidden layer.
            classes_count: Number of output units (width of the label vector).

        Returns:
            A tuple ``(x, y, out_result, loss, step)``:
            ``x`` is the ``[None, 784]`` float32 input placeholder,
            ``y`` is the ``[None, classes_count]`` float32 label placeholder,
            ``out_result`` is the sigmoid output tensor,
            ``loss`` is the sum of squared differences between output and
            labels, and ``step`` is the gradient-descent op (learning rate
            0.1) minimising ``loss``.
        """
        input_size = 784
        x = tf.placeholder(tf.float32, shape=[None, input_size], name='x')
        y = tf.placeholder(tf.float32, shape=[None, classes_count], name='y')

        # Hidden layer: sigmoid(x @ W + b), with weights and bias drawn from
        # a normal distribution.
        hidden_weights = tf.Variable(tf.random_normal([input_size, nhidden]))
        hidden_bias = tf.Variable(tf.random_normal([1, nhidden]))
        hidden_activation = tf.sigmoid(
            tf.add(tf.matmul(x, hidden_weights), hidden_bias))

        # Output layer, built the same way on top of the hidden activation.
        output_weights = tf.Variable(tf.random_normal([nhidden, classes_count]))
        output_bias = tf.Variable(tf.random_normal([1, classes_count]))
        out_result = tf.sigmoid(
            tf.add(tf.matmul(hidden_activation, output_weights), output_bias))

        # Sum-of-squares error over every sample and every class.
        loss = tf.reduce_sum(tf.square(tf.subtract(out_result, y)))
        step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

        # Registered so that tf.summary.merge_all() picks the loss up.
        tf.summary.scalar("loss", loss)

        return x, y, out_result, loss, step
    
    
    def do_train():
        """Train the digit network for 800 steps and print the results.

        Every step trains on a freshly generated batch of the ten digits.
        Every 50th step the loss is written to a summary log directory named
        ``logs-<timestamp>`` and printed together with the number of correctly
        classified samples. After training, a noisier batch (150 erased
        pixels) is generated, displayed and classified, and the predicted
        digit for each image is printed.
        """
        x, y, y_, loss, step = build_network(10, 10)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            summary_merged = tf.summary.merge_all()
            writer = tf.summary.FileWriter('logs-' + str(time.time()), sess.graph)
            try:
                for i in range(800):
                    x_features, y_labels = load_data(sess)
                    feed = {x: x_features, y: y_labels}
                    sess.run(step, feed_dict=feed)
                    if (i + 1) % 50 == 0:
                        cur_loss, summary_, pred_ys = sess.run(
                            [loss, summary_merged, y_], feed_dict=feed)
                        writer.add_summary(summary_, i)

                        # Pick the best class per sample (axis 1). Reducing
                        # over axis 0 only looked right because the batch is
                        # exactly the digits 0..9 in order. NumPy is used so
                        # that no new ops are added to the graph on each
                        # evaluation.
                        predicted = np.argmax(pred_ys, axis=1)
                        expected = np.argmax(y_labels, axis=1)
                        # float() keeps the printed count in the "4.0" form.
                        r = float(np.sum(predicted == expected))
                        print(i + 1, ': loss: ', cur_loss, '正确个数:', r)

                print('*************************')
                x_features, y_labels = load_data(sess, 150, should_display_images=True)
                pred_ys = sess.run(y_, feed_dict={x: x_features})
                r = np.argmax(pred_ys, axis=1)
                print('图片识别结果:', r)
            finally:
                # Close the event file even if training raises.
                writer.close()
    
    
    # Run the training demo only when this file is executed as a script,
    # not when it is imported as a module.
    if __name__ == '__main__':
        do_train()
    

      

    50 : loss:  7.3588676 正确个数: 4.0
    100 : loss:  6.6502814 正确个数: 5.0
    150 : loss:  5.26784 正确个数: 7.0
    200 : loss:  4.0591483 正确个数: 9.0
    250 : loss:  3.4379258 正确个数: 8.0
    300 : loss:  3.114149 正确个数: 8.0
    350 : loss:  2.0274947 正确个数: 9.0
    400 : loss:  1.4823446 正确个数: 10.0
    450 : loss:  1.4051719 正确个数: 10.0
    500 : loss:  0.91150457 正确个数: 10.0
    550 : loss:  0.7835213 正确个数: 10.0
    600 : loss:  0.72512466 正确个数: 10.0
    650 : loss:  0.56525075 正确个数: 10.0
    700 : loss:  0.4699742 正确个数: 10.0
    750 : loss:  0.45453963 正确个数: 10.0
    800 : loss:  0.45089394 正确个数: 10.0
    *************************
    图片识别结果: [0 1 2 3 4 5 6 7 8 9]

  • 相关阅读:
    女白领在家玩打地鼠游戏,无意间学会python编程,还有教程有源码
    Python入门小游戏,炫酷打地鼠教程第二部分,都是干货
    从python入门开始,玩这个炸弹超人小游戏,打通关就可以掌握编程
    Python如何入门?直接按这个方式玩炸弹超人小游戏,就能掌握编程
    Python如何入门?搭配这些游戏,学习高效还有趣
    资深程序员教你,利用python预测NBA比赛结果,太精彩了
    Python入门小迷宫,走完这个迷宫,就能掌握python编程基础
    从零基础开始,用python手把手教你玩跳一跳小游戏,直接打出高分
    戏精程序员,用python开发了一个女朋友,天天秀恩爱
    Python入门小游戏之坦克大战,不懂编程都能做出来,附所有源码
  • 原文地址:https://www.cnblogs.com/aarond/p/tf.html
Copyright © 2011-2022 走看看