  • TensorFlow MNIST Best Practices

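    The CNN we trained on MNIST earlier already reaches a high recognition accuracy, but every run takes a long time. In practice, having to retrain the network every time before it can recognize anything would be very inconvenient.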

    So let's learn how to save the learned parameters, and then load them directly whenever we need fast recognition.

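    The code in this section comes from the sample program in Section 5.5, "TensorFlow Best Practices", of the book 《Tensorflow 实战Google深度学习框架》, with a few small adjustments to the book's code.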

    mnist_inference.py:

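    #coding=utf-8
    import tensorflow as tf
    # Forward pass shared by mnist_train.py and mnist_eval.py (TensorFlow 1.x API).
    
    INPUT_NODE = 784    # 28 x 28 pixels per image
    OUTPUT_NODE = 10    # one logit per digit 0-9
    LAYER1_NODE = 500   # hidden-layer width
    
    def get_weight_variable(shape, regularizer):
        weights = tf.get_variable("weights", shape, initializer=tf.truncated_normal_initializer(stddev=0.1))
        # Inside tf.variable_scope('layer1') the full name is "layer1/weights";
        # the Saver later uses that name to store and restore the value.
        # During training, the L2 term is added to the 'losses' collection;
        # at evaluation time regularizer is None and this step is skipped.
        if regularizer is not None:
            tf.add_to_collection('losses', regularizer(weights))
        return weights
    
    def inference(input_tensor, regularizer):
        # Hidden layer: ReLU(x * W1 + b1)
        with tf.variable_scope('layer1'):
            weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer)
            biases = tf.get_variable("biases", [LAYER1_NODE], initializer=tf.constant_initializer(0.0))
            layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases)
    
        # Output layer: raw logits; softmax is applied later, inside the loss.
        with tf.variable_scope('layer2'):
            weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer)
            biases = tf.get_variable("biases", [OUTPUT_NODE], initializer=tf.constant_initializer(0.0))
            layer2 = tf.matmul(layer1, weights) + biases
    
        return layer2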

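    This file defines the forward-propagation function. Both training and testing need it, so pulling it out into its own module makes it easier to use and guarantees that training and testing run exactly the same network.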
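    To make the layer structure concrete, here is a minimal pure-Python sketch of the same two-layer computation that inference() builds: a ReLU hidden layer followed by a linear output layer. It does not use TensorFlow; the dictionary keys only mimic the scope/name variable names, and the tiny 2-2-2 sizes and all numbers are hypothetical.

```python
# Pure-Python sketch of the forward pass in mnist_inference.inference().
# Parameter values and the tiny 2-2-2 sizes are hypothetical, chosen so the
# example runs instantly; the real network is 784-500-10.

def matmul(a, b):
    """(n x k) @ (k x m) for matrices stored as lists of rows."""
    return [[sum(row[t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for row in a]

def add_bias(m, bias):
    """Add the bias vector to every row (broadcast over the batch)."""
    return [[v + bv for v, bv in zip(row, bias)] for row in m]

def relu(m):
    return [[max(0.0, v) for v in row] for row in m]

def inference(x, params):
    # layer1 = relu(x @ W1 + b1)
    layer1 = relu(add_bias(matmul(x, params["layer1/weights"]), params["layer1/biases"]))
    # layer2 = layer1 @ W2 + b2, no activation: softmax is applied inside the loss
    return add_bias(matmul(layer1, params["layer2/weights"]), params["layer2/biases"])

params = {
    "layer1/weights": [[1.0, -1.0], [0.5, 2.0]],
    "layer1/biases": [0.0, 0.0],
    "layer2/weights": [[2.0, 0.0], [0.0, 3.0]],
    "layer2/biases": [0.1, 0.2],
}
logits = inference([[1.0, -1.0]], params)  # a batch containing one example
print(logits)
```

    Because training and evaluation both call the same inference(), the same parameters always produce the same logits, which is the point of putting it in its own module.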

    Sharing variables: tf.variable_scope and tf.get_variable:

    For detailed usage and how it works, see the tutorial "Sharing Variables".

    get_variable

     weights = tf.get_variable("weights", shape, initializer = tf.truncated_normal_initializer(stddev=0.1))

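    Line 10 of the source (the first statement of get_weight_variable) obtains the variable with tf.get_variable instead of creating it directly with tf.Variable.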

    When the network is trained, these variables are created;

    when it is tested, their values are loaded from the model saved during training.
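    Creating at training time and loading at test time both hinge on variable names. The sketch below is a hypothetical toy store, not TensorFlow's implementation, that mimics the rules: the full name is scope/name, requesting an existing name without reuse is an error, reuse=True returns the existing variable, and restoring matches saved values by full name.

```python
# A toy model of the tf.variable_scope + tf.get_variable contract.
# ToyVariableStore is hypothetical and is NOT how TensorFlow implements it;
# it only reproduces the naming and reuse rules described above.

class ToyVariableStore:
    def __init__(self):
        self._vars = {}  # full name ("scope/name") -> value

    def get_variable(self, scope, name, initializer, reuse=False):
        full_name = scope + "/" + name
        if reuse:
            # reuse=True: the variable must already exist; return it unchanged.
            if full_name not in self._vars:
                raise ValueError("Variable %s does not exist" % full_name)
            return self._vars[full_name]
        # reuse=False: creating the same name twice is an error.
        if full_name in self._vars:
            raise ValueError("Variable %s already exists" % full_name)
        self._vars[full_name] = initializer()
        return self._vars[full_name]

    def restore(self, checkpoint):
        """Overwrite every variable with the value saved under its full name."""
        for full_name in self._vars:
            if full_name not in checkpoint:
                raise KeyError("%s not found in checkpoint" % full_name)
            self._vars[full_name] = checkpoint[full_name]

# Training: the variable is created by its initializer.
train_store = ToyVariableStore()
w = train_store.get_variable("layer1", "weights", lambda: [0.1, -0.2])
checkpoint = {"layer1/weights": w}  # stands in for what the Saver writes to disk

# Testing: the graph is rebuilt with fresh values, then loaded by name.
eval_store = ToyVariableStore()
eval_store.get_variable("layer1", "weights", lambda: [0.0, 0.0])
eval_store.restore(checkpoint)
print(eval_store.get_variable("layer1", "weights", None, reuse=True))
```

    The takeaway is that the evaluation graph must recreate the same names (layer1/weights and so on), which is exactly what calling the shared inference() guarantees.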

    mnist_train.py:

    #coding=utf-8
    import os
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    import mnist_inference
    
    BATCH_SIZE = 100
    LEARNING_RATE_BASE = 0.8
    LEARNING_RATE_DECAY = 0.99
    REGULARAZTION_RATE = 0.0001
    TRAINING_STEPS = 30000
    MOVING_AVERAGE_DECAY = 0.99
    
    def train(mnist):
        # A batch of flattened images and their one-hot labels.
        x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name="x-input")
        y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name="y-input")
    
        regularizer = tf.contrib.layers.l2_regularizer(REGULARAZTION_RATE)
        y = mnist_inference.inference(x, regularizer)
    
        # Step counter; trainable=False keeps it out of tf.trainable_variables().
        global_step = tf.Variable(0, trainable=False)
    
        # Keep a moving-average shadow copy of every trainable variable.
        variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
        variables_averages_op = variable_averages.apply(tf.trainable_variables())
    
        # The sparse loss expects class indices, so convert one-hot labels with argmax.
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
        cross_entropy_mean = tf.reduce_mean(cross_entropy)
        # Total loss = data loss + the L2 terms collected in mnist_inference.
        loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))
    
        # The rate shrinks by a factor of LEARNING_RATE_DECAY every
        # num_examples / BATCH_SIZE steps (smoothly, since staircase defaults to False).
        learning_rate = tf.train.exponential_decay(
            LEARNING_RATE_BASE, global_step,
            mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY)
        # minimize() also increments global_step on every run.
        train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
        # A single op that runs both the gradient step and the moving-average update.
        with tf.control_dependencies([train_step, variables_averages_op]):
            train_op = tf.no_op(name='train')
    
        saver = tf.train.Saver()
        # Make sure the checkpoint directory exists before saving into it.
        if not os.path.exists("./mnist_variables"):
            os.makedirs("./mnist_variables")
        with tf.Session() as sess:
            init = tf.global_variables_initializer()
            sess.run(init)
            for i in range(TRAINING_STEPS):
                xs, ys = mnist.train.next_batch(BATCH_SIZE)
                _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
                # Every 1000 steps, report the loss and write a checkpoint;
                # global_step=... appends "-<step>" to the checkpoint file name.
                if i % 1000 == 0:
                    print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
                    saver.save(sess, "./mnist_variables/trained_variables.ckpt", global_step=global_step)
    
    def main(argv=None):
        mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
        train(mnist)
    
    if __name__ == '__main__':
        tf.app.run()
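    Two steps of the training graph are easy to misread, so here they are as plain Python under stated assumptions: the learning-rate formula documented for tf.train.exponential_decay (with the default staircase=False), and the per-example loss that sparse_softmax_cross_entropy_with_logits computes from a class index, which is why the one-hot labels first go through tf.argmax(y_, 1). The 55000-image training split is assumed to be input_data's default; the logits are hypothetical.

```python
import math

def exponential_decay(base_rate, global_step, decay_steps, decay_rate, staircase=False):
    # Documented formula: base_rate * decay_rate ** (global_step / decay_steps)
    exponent = global_step / decay_steps
    if staircase:
        exponent = math.floor(exponent)  # decay in discrete jumps instead
    return base_rate * decay_rate ** exponent

def argmax(values):
    return max(range(len(values)), key=lambda i: values[i])

def sparse_softmax_cross_entropy(logits, label):
    # -log(softmax(logits)[label]) for one example; label is a class index.
    m = max(logits)  # subtract the max so exp() cannot overflow
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

# Assumed: 55000 training images / BATCH_SIZE 100 = 550 steps per decay period.
decay_steps = 55000 / 100
print(exponential_decay(0.8, 550, decay_steps, 0.99))  # one period in: 0.8 * 0.99

one_hot = [0.0, 0.0, 1.0]
label = argmax(one_hot)  # what tf.argmax(y_, 1) does for each row
print(label, sparse_softmax_cross_entropy([0.0, 0.0, 0.0], label))
```

    With uniform logits over three classes the loss is log(3), the value of a model that has learned nothing yet.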

    Use a regularization function to reduce overfitting:

    regularizer = tf.contrib.layers.l2_regularizer(REGULARAZTION_RATE)
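    As a sketch of what this regularizer contributes: tf.contrib.layers.l2_regularizer(scale) returns a function computing scale * tf.nn.l2_loss(w), i.e. scale * sum(w ** 2) / 2. The pure-Python version below reproduces that arithmetic and the 'losses' collection pattern from mnist_inference.py; the weight matrices and the data-loss value are hypothetical.

```python
# Pure-Python sketch of the penalty tf.contrib.layers.l2_regularizer(scale)
# adds per weight matrix: scale * tf.nn.l2_loss(w) = scale * sum(w ** 2) / 2.

def l2_loss(weights):
    return sum(w * w for row in weights for w in row) / 2

def l2_regularizer(scale):
    def regularize(weights):
        return scale * l2_loss(weights)
    return regularize

REGULARAZTION_RATE = 0.0001
regularizer = l2_regularizer(REGULARAZTION_RATE)

losses = []  # plays the role of the 'losses' collection
for weights in ([[1.0, 2.0], [3.0, 4.0]], [[0.5], [-0.5]]):  # hypothetical weights
    losses.append(regularizer(weights))  # like tf.add_to_collection('losses', ...)

cross_entropy_mean = 0.3  # hypothetical data loss
loss = cross_entropy_mean + sum(losses)  # like cross_entropy_mean + tf.add_n(...)
print(losses, loss)
```

    Large weights raise the total loss, so gradient descent is pushed toward smaller weights, which is how the penalty counters overfitting.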

    The lines that define global_step and apply the moving average:

    global_step = tf.Variable(0, trainable=False)
    ...

    variables_averages_op = variable_averages.apply(tf.trainable_variables())

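    tf.trainable_variables() does not return global_step, because global_step was created with trainable=False.

    If tf.trainable_variables() were replaced with tf.all_variables() (deprecated in later 1.x releases in favor of tf.global_variables()), global_step would be included too. That is not wanted: averaging a step counter is meaningless, and ExponentialMovingAverage.apply() only accepts floating-point variables anyway.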
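    The moving average itself is a simple update. The pure-Python sketch below follows the rule documented for tf.train.ExponentialMovingAverage: shadow = decay * shadow + (1 - decay) * variable, where passing num_updates (global_step here) caps the decay at (1 + num_updates) / (10 + num_updates) so that early shadows catch up quickly. It also shows the trainable filter leaving global_step out. All names and values are hypothetical.

```python
# Pure-Python sketch of the update tf.train.ExponentialMovingAverage applies.
# Variable names and values below are hypothetical.

def ema_decay(decay, num_updates=None):
    # With num_updates, early steps use a smaller decay:
    # min(decay, (1 + n) / (10 + n)).
    if num_updates is None:
        return decay
    return min(decay, (1.0 + num_updates) / (10.0 + num_updates))

def ema_update(shadow, value, decay):
    # shadow = decay * shadow + (1 - decay) * value
    return decay * shadow + (1.0 - decay) * value

# Only trainable variables get a shadow copy, so global_step is left out.
variables = {"layer1/weights": (0.0, True), "global_step": (0, False)}  # name -> (value, trainable)
shadows = {name: value for name, (value, trainable) in variables.items() if trainable}

# Suppose a training step moves the weight from 0.0 to 5.0 while global_step is 0.
d = ema_decay(0.99, num_updates=0)  # min(0.99, 1 / 10)
shadows["layer1/weights"] = ema_update(shadows["layer1/weights"], 5.0, d)
print(d, shadows)
```

    Later in training, (1 + n) / (10 + n) exceeds 0.99 once n passes about 890, so the configured MOVING_AVERAGE_DECAY takes over and the shadow changes slowly.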

    mnist_eval.py:

    #coding=utf-8
    import time
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    import mnist_inference
    import mnist_train
    
    # Re-evaluate the newest checkpoint every 10 seconds.
    EVAL_INTERVAL_SECS = 10
    
    def evaluate(mnist):
        with tf.Graph().as_default() as g:
            x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name="x-input")
            y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name="y-input")
            # Note: despite its name, this feeds the whole test set in one run.
            validate_feed = {x: mnist.test.images, y_: mnist.test.labels}
    
            # Same forward pass as in training; no regularizer is needed here.
            y = mnist_inference.inference(x, None)
            correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
            accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    
            # variables_to_restore() maps the shadow (moving-average) names stored in
            # the checkpoint onto the model variables, so the averaged values are loaded.
            variable_averages = tf.train.ExponentialMovingAverage(mnist_train.MOVING_AVERAGE_DECAY)
            variables_to_restore = variable_averages.variables_to_restore()
            saver = tf.train.Saver(variables_to_restore)
    
            while True:
                with tf.Session() as sess:
                    # Reads the "checkpoint" index file that saver.save maintains.
                    ckpt = tf.train.get_checkpoint_state("./mnist_variables/")
                    if ckpt and ckpt.model_checkpoint_path:
                        saver.restore(sess, ckpt.model_checkpoint_path)
                        # The checkpoint file name ends in "-<global_step>".
                        global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                        accuracy_score = sess.run(accuracy, feed_dict=validate_feed)
                        print("After %s training step(s), validation accuracy = %g " % (global_step, accuracy_score))
                    else:
                        print('No checkpoint file found')
                        return
                    time.sleep(EVAL_INTERVAL_SECS)
    
    def main(argv=None):
        mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
        evaluate(mnist)
    
    if __name__ == '__main__':
        tf.app.run()
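    The evaluation bookkeeping can be checked without TensorFlow. This sketch reproduces the argmax-compare-and-average accuracy and a slightly more defensive way to read the step number from the checkpoint path (os.path.basename instead of splitting on '/'). The sample outputs, labels and path are hypothetical.

```python
import os

def argmax(values):
    return max(range(len(values)), key=lambda i: values[i])

def accuracy(logits_batch, one_hot_labels):
    # Same computation as correct_prediction / accuracy in mnist_eval.py:
    # compare argmax per row, then average the 0/1 results.
    correct = [argmax(y) == argmax(y_) for y, y_ in zip(logits_batch, one_hot_labels)]
    return sum(correct) / len(correct)

def step_from_checkpoint_path(path):
    # saver.save(..., global_step=N) appends "-N" to the file name;
    # basename() + rsplit() only looks at the last "-" of the file name.
    return int(os.path.basename(path).rsplit("-", 1)[-1])

logits_batch = [[0.1, 2.0], [3.0, -1.0], [0.0, 1.0]]  # hypothetical model outputs
labels = [[0, 1], [1, 0], [1, 0]]                    # one-hot labels
print(accuracy(logits_batch, labels))
print(step_from_checkpoint_path("./mnist_variables/trained_variables.ckpt-1001"))  # hypothetical path
```

    Using os.path.basename also copes with platform-specific path separators, which the plain split('/') does not.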

    (To be continued...)

  • Original article: https://www.cnblogs.com/guolaomao/p/8028600.html