zoukankan      html  css  js  c++  java
  • TensorFlow 之 手写数字识别MNIST

    官方文档: 
    MNIST For ML Beginners - https://www.tensorflow.org/get_started/mnist/beginners 
    Deep MNIST for Experts - https://www.tensorflow.org/get_started/mnist/pros 

    版本: 
    TensorFlow 1.2.0 + Flask 0.12 + Gunicorn 19.6 

    相关文章: 
    TensorFlow 之 入门体验 
    TensorFlow 之 手写数字识别MNIST 
    TensorFlow 之 物体检测 
    TensorFlow 之 构建人物识别系统 

    MNIST相当于机器学习界的Hello World。 

    这里在页面通过 Canvas 画一个数字,然后传给TensorFlow识别,分别给出Softmax回归模型、多层卷积网络的识别结果。 

    (1)文件结构 

    │  main.py 
    │  requirements.txt 
    │  runtime.txt 
    ├─mnist 
    │  │  convolutional.py 
    │  │  model.py 
    │  │  regression.py 
    │  │  __init__.py 
    │  └─data 
    │          convolutional.ckpt.data-00000-of-00001 
    │          convolutional.ckpt.index 
    │          regression.ckpt.data-00000-of-00001 
    │          regression.ckpt.index 
    ├─src 
    │  └─js 
    │          main.js 
    ├─static 
    │  ├─css 
    │  │      bootstrap.min.css 
    │  └─js 
    │          jquery.min.js 
    │          main.js 
    └─templates 
            index.html 

    (2)训练数据 

    下载以下文件放入/tmp/data/,不用解压,训练代码会自动解压。 
    引用
    http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz 
    http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz 
    http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz 
    http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz


    执行命令训练数据(Softmax回归模型、多层卷积网络) 
    Shell代码  收藏代码
    1. # python regression.py  
    2. # python convolutional.py  


    执行完成后 在 mnist/data/ 里会生成以下几个文件,重新训练前需要把这几个文件先删掉。 
    引用
    convolutional.ckpt.data-00000-of-00001 
    convolutional.ckpt.index 
    regression.ckpt.data-00000-of-00001 
    regression.ckpt.index


    (3)启动Web服务测试 

    Shell代码  收藏代码
    1. # cd /usr/local/tensorflow2/tensorflow-models/tf-mnist  
    2. # pip install -r requirements.txt  
    3. # gunicorn main:app --log-file=- --bind=localhost:8000  


    浏览器中访问:http://localhost:8000 

    *** 运行的TensorFlow版本、数据训练的模型、还有这里Canvas的转换都对识别率有一定的影响~! 

    (4)源代码 

    Web部分比较简单,页面上放置一个Canvas,鼠标抬起时将Canvas的图像通过Ajax传给后台API,然后显示API结果。 
    引用
    src/js/main.js -> static/js/main.js 
    templates/index.html


    main.py 
    Python代码  收藏代码
    1. import numpy as np  
    2. import tensorflow as tf  
    3. from flask import Flask, jsonify, render_template, request  
    4.   
    5. from mnist import model  
    6.   
    7. x = tf.placeholder("float", [None, 784])  
    8. sess = tf.Session()  
    9.   
    10. # restore trained data  
    11. with tf.variable_scope("regression"):  
    12.     y1, variables = model.regression(x)  
    13. saver = tf.train.Saver(variables)  
    14. saver.restore(sess, "mnist/data/regression.ckpt")  
    15.   
    16. with tf.variable_scope("convolutional"):  
    17.     keep_prob = tf.placeholder("float")  
    18.     y2, variables = model.convolutional(x, keep_prob)  
    19. saver = tf.train.Saver(variables)  
    20. saver.restore(sess, "mnist/data/convolutional.ckpt")  
    21.   
    22. def regression(input):  
    23.     return sess.run(y1, feed_dict={x: input}).flatten().tolist()  
    24.   
    25. def convolutional(input):  
    26.     return sess.run(y2, feed_dict={x: input, keep_prob: 1.0}).flatten().tolist()  
    27.   
    28. # webapp  
    29. app = Flask(__name__)  
    30.  
    31. @app.route('/api/mnist', methods=['POST'])  
    32. def mnist():  
    33.     input = ((255 - np.array(request.json, dtype=np.uint8)) / 255.0).reshape(1, 784)  
    34.     output1 = regression(input)  
    35.     output2 = convolutional(input)  
    36.     print(output1)  
    37.     print(output2)  
    38.     return jsonify(results=[output1, output2])  
    39.  
    40. @app.route('/')  
    41. def main():  
    42.     return render_template('index.html')  
    43.   
    44. if __name__ == '__main__':  
    45.     app.run()  


    mnist/model.py 
    Python代码  收藏代码
    1. import tensorflow as tf  
    2.   
    3.   
    4. # Softmax Regression Model  
    5. def regression(x):  
    6.     W = tf.Variable(tf.zeros([784, 10]), name="W")  
    7.     b = tf.Variable(tf.zeros([10]), name="b")  
    8.     y = tf.nn.softmax(tf.matmul(x, W) + b)  
    9.     return y, [W, b]  
    10.   
    11.   
    12. # Multilayer Convolutional Network  
    13. def convolutional(x, keep_prob):  
    14.     def conv2d(x, W):  
    15.         return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')  
    16.   
    17.     def max_pool_2x2(x):  
    18.         return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')  
    19.   
    20.     def weight_variable(shape):  
    21.         initial = tf.truncated_normal(shape, stddev=0.1)  
    22.         return tf.Variable(initial)  
    23.   
    24.     def bias_variable(shape):  
    25.         initial = tf.constant(0.1, shape=shape)  
    26.         return tf.Variable(initial)  
    27.   
    28.     # First Convolutional Layer  
    29.     x_image = tf.reshape(x, [-1, 28, 28, 1])  
    30.     W_conv1 = weight_variable([5, 5, 1, 32])  
    31.     b_conv1 = bias_variable([32])  
    32.     h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)  
    33.     h_pool1 = max_pool_2x2(h_conv1)  
    34.     # Second Convolutional Layer  
    35.     W_conv2 = weight_variable([5, 5, 32, 64])  
    36.     b_conv2 = bias_variable([64])  
    37.     h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)  
    38.     h_pool2 = max_pool_2x2(h_conv2)  
    39.     # Densely Connected Layer  
    40.     W_fc1 = weight_variable([7 * 7 * 64, 1024])  
    41.     b_fc1 = bias_variable([1024])  
    42.     h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])  
    43.     h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)  
    44.     # Dropout  
    45.     h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)  
    46.     # Readout Layer  
    47.     W_fc2 = weight_variable([1024, 10])  
    48.     b_fc2 = bias_variable([10])  
    49.     y = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)  
    50.     return y, [W_conv1, b_conv1, W_conv2, b_conv2, W_fc1, b_fc1, W_fc2, b_fc2]  


    mnist/convolutional.py 
    Python代码  收藏代码
    1. import os  
    2. import model  
    3. import tensorflow as tf  
    4.   
    5. from tensorflow.examples.tutorials.mnist import input_data  
    6. data = input_data.read_data_sets("/tmp/data/", one_hot=True)  
    7.   
    8. # model  
    9. with tf.variable_scope("convolutional"):  
    10.     x = tf.placeholder(tf.float32, [None, 784])  
    11.     keep_prob = tf.placeholder(tf.float32)  
    12.     y, variables = model.convolutional(x, keep_prob)  
    13.   
    14. # train  
    15. y_ = tf.placeholder(tf.float32, [None, 10])  
    16. cross_entropy = -tf.reduce_sum(y_ * tf.log(y))  
    17. train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)  
    18. correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))  
    19. accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))  
    20.   
    21. saver = tf.train.Saver(variables)  
    22. with tf.Session() as sess:  
    23.     sess.run(tf.global_variables_initializer())  
    24.     for i in range(20000):  
    25.         batch = data.train.next_batch(50)  
    26.         if i % 100 == 0:  
    27.             train_accuracy = accuracy.eval(feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0})  
    28.             print("step %d, training accuracy %g" % (i, train_accuracy))  
    29.         sess.run(train_step, feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})  
    30.   
    31.     print(sess.run(accuracy, feed_dict={x: data.test.images, y_: data.test.labels, keep_prob: 1.0}))  
    32.   
    33.     path = saver.save(  
    34.         sess, os.path.join(os.path.dirname(__file__), 'data', 'convolutional.ckpt'),  
    35.         write_meta_graph=False, write_state=False)  
    36.     print("Saved:", path)  


    mnist/regression.py 
    Python代码  收藏代码
    1. import os  
    2. import model  
    3. import tensorflow as tf  
    4.   
    5. from tensorflow.examples.tutorials.mnist import input_data  
    6. data = input_data.read_data_sets("/tmp/data/", one_hot=True)  
    7.   
    8. # model  
    9. with tf.variable_scope("regression"):  
    10.     x = tf.placeholder(tf.float32, [None, 784])  
    11.     y, variables = model.regression(x)  
    12.   
    13. # train  
    14. y_ = tf.placeholder("float", [None, 10])  
    15. cross_entropy = -tf.reduce_sum(y_ * tf.log(y))  
    16. train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)  
    17. correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))  
    18. accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))  
    19.   
    20. saver = tf.train.Saver(variables)  
    21. with tf.Session() as sess:  
    22.     sess.run(tf.global_variables_initializer())  
    23.     for _ in range(1000):  
    24.         batch_xs, batch_ys = data.train.next_batch(100)  
    25.         sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})  
    26.   
    27.     print(sess.run(accuracy, feed_dict={x: data.test.images, y_: data.test.labels}))  
    28.   
    29.     path = saver.save(  
    30.         sess, os.path.join(os.path.dirname(__file__), 'data', 'regression.ckpt'),  
    31.         write_meta_graph=False, write_state=False)  
    32.     print("Saved:", path)  


    参考: 
    http://memo.sugyan.com/entry/20151124/1448292129
  • 相关阅读:
    解决Android Studio Gradle DSL method not found: 'android()'
    【转】关于ListView中notifyDataSetChanged()刷新数据不更新原因
    设计模式-单例模式
    IE浏览器让DIV居中
    Java通过DOM解析XML
    git 配置文件位置;git配置文件设置
    git config配置
    dos2unix
    文件的编码问题解决
    git diff old mode 100644 new mode 100755
  • 原文地址:https://www.cnblogs.com/Ph-one/p/9074706.html
Copyright © 2011-2022 走看看