  • Chapter 5: MNIST Digit Recognition (Part 2)

    5.4 TensorFlow Model Persistence

    5.4.1. Saving a model as ckpt files

    When loading a model, you must first define a computation graph whose structure is exactly the same as the original; only then can the model be restored, and the restored variables need no initialization.
    Saving a model this way produces three files in the target directory:
    * the .ckpt.meta file, which stores the structure of the TensorFlow computation graph;
    * the .ckpt file, which stores the values of all variables in the graph;
    * the checkpoint file, which lists all model files in the directory.
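The checkpoint file is a small plain-text file of `key: "value"` lines recording the latest and retained checkpoint paths. As a plain-Python illustration (not a TensorFlow API; the sample contents below are made up to show the usual layout):

```python
# Sketch: parse the text-format `checkpoint` file that tf.train.Saver maintains.
# The sample contents are illustrative, not taken from a real run.
sample = '''model_checkpoint_path: "model.ckpt-7001"
all_model_checkpoint_paths: "model.ckpt-6001"
all_model_checkpoint_paths: "model.ckpt-7001"'''

def parse_checkpoint_state(text):
    latest, history = None, []
    for line in text.splitlines():
        key, _, value = line.partition(": ")
        value = value.strip().strip('"')
        if key == "model_checkpoint_path":
            latest = value                 # most recent checkpoint
        elif key == "all_model_checkpoint_paths":
            history.append(value)          # all retained checkpoints
    return latest, history

latest, history = parse_checkpoint_state(sample)
print(latest)   # model.ckpt-7001
print(history)  # ['model.ckpt-6001', 'model.ckpt-7001']
```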

    import tensorflow as tf
    #Save a model that computes the sum of two variables
    v1 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
    v2 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
    result = v1 + v2
    
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()
    
    with tf.Session() as sess:
        sess.run(init_op)
        saver.save(sess, "Saved_model/model.ckpt")
    #Load the saved model that computes the sum of the two variables
    with tf.Session() as sess:
        saver.restore(sess, "Saved_model/model.ckpt")
        print(sess.run(result))
    
    INFO:tensorflow:Restoring parameters from Saved_model/model.ckpt
    [-1.6226364]
    #Load the persisted graph directly. Because v3 was never saved, running it below raises an error
    saver = tf.train.import_meta_graph("Saved_model/model.ckpt.meta")
    v3 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
    
    with tf.Session() as sess:
        saver.restore(sess, "Saved_model/model.ckpt")
        print(sess.run(v1))
        print(sess.run(v2))
        print(sess.run(v3))
    INFO:tensorflow:Restoring parameters from Saved_model/model.ckpt
    [-0.81131822]
    [-0.81131822]
    
    # Variable renaming: a dictionary links the variable names used when the model was saved to the variables to be loaded
    v1 = tf.Variable(tf.constant(1.0, shape=[1]), name = "other-v1")
    v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "other-v2")
    saver = tf.train.Saver({"v1": v1, "v2": v2})

    5.4.2.1 Saving a moving-average model

    import tensorflow as tf
    #Apply an exponential moving average
    v = tf.Variable(0, dtype=tf.float32, name="v")
    for variables in tf.global_variables(): print(variables.name)
        
    ema = tf.train.ExponentialMovingAverage(0.99)
    maintain_averages_op = ema.apply(tf.global_variables())
    for variables in tf.global_variables(): print(variables.name)
    v:0
    v:0
    v/ExponentialMovingAverage:0
    
    #Save the moving-average model
    saver = tf.train.Saver()
    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        
        sess.run(tf.assign(v, 10))
        sess.run(maintain_averages_op)
        # Saving stores both variables, v:0 and v/ExponentialMovingAverage:0.
        saver.save(sess, "Saved_model/model2.ckpt")
        print(sess.run([v, ema.average(v)]))
    [10.0, 0.099999905]
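The ~0.1 printed above can be reproduced by hand. With no `num_updates` argument, `ExponentialMovingAverage` updates its shadow variable as `shadow = decay * shadow + (1 - decay) * value`; a quick check in plain Python:

```python
# Shadow-variable update rule used by tf.train.ExponentialMovingAverage:
#   shadow = decay * shadow + (1 - decay) * value
decay = 0.99
shadow = 0.0    # the shadow starts at the variable's initial value, 0
value = 10.0    # after sess.run(tf.assign(v, 10))
shadow = decay * shadow + (1 - decay) * value
print(shadow)   # ~0.1, matching ema.average(v) above

# When a num_updates argument (e.g. global_step) is supplied, the effective
# decay becomes min(decay, (1 + num_updates) / (10 + num_updates)).
def effective_decay(decay, num_updates=None):
    if num_updates is None:
        return decay
    return min(decay, (1.0 + num_updates) / (10.0 + num_updates))
```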
    
    #Load the moving-average model
    v = tf.Variable(0, dtype=tf.float32, name="v")
    
    # Through variable renaming, the saved moving average of the original variable v is assigned directly to v.
    saver = tf.train.Saver({"v/ExponentialMovingAverage": v})
    with tf.Session() as sess:
        saver.restore(sess, "Saved_model/model2.ckpt")
        print(sess.run(v))
    INFO:tensorflow:Restoring parameters from Saved_model/model2.ckpt
    0.0999999

    5.4.2.2 A usage example of the variables_to_restore function

    import tensorflow as tf
    v = tf.Variable(0, dtype=tf.float32, name="v")
    ema = tf.train.ExponentialMovingAverage(0.99)
    print(ema.variables_to_restore())
    
    #Equivalent to saver = tf.train.Saver(ema.variables_to_restore())
    saver = tf.train.Saver({"v/ExponentialMovingAverage": v})
    with tf.Session() as sess:
        saver.restore(sess, "Saved_model/model2.ckpt")
        print(sess.run(v))
    {u'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}

    5.4.3. Saving a model as a pb file

    #How to save a pb file
    import tensorflow as tf
    from tensorflow.python.framework import graph_util
    
    v1 = tf.Variable(tf.constant(1.0, shape=[1]), name = "v1")
    v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "v2")
    result = v1 + v2
    
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_op)
        graph_def = tf.get_default_graph().as_graph_def()
        output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add'])
        with tf.gfile.GFile("Saved_model/combined_model.pb", "wb") as f:
               f.write(output_graph_def.SerializeToString())
    
    INFO:tensorflow:Froze 2 variables.
    Converted 2 variables to const ops.
    ------------------------------------------------------------------------
    #Load the pb file
    from tensorflow.python.platform import gfile
    with tf.Session() as sess:
        model_filename = "Saved_model/combined_model.pb"
       
        with gfile.FastGFile(model_filename, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
    
        result = tf.import_graph_def(graph_def, return_elements=["add:0"])
        print(sess.run(result))
    
    [array([ 3.], dtype=float32)]

    A tensor name carries a trailing :0, which marks it as the first output of a computation node; the name of the node itself has no :0 suffix.
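As a small plain-Python illustration of this naming convention (not a TensorFlow API), a tensor name splits into the node name and the output index:

```python
def split_tensor_name(name):
    """Split a tensor name 'op_name:index' into (op_name, output_index)."""
    op_name, _, index = name.partition(":")
    return op_name, int(index)

print(split_tensor_name("add:0"))                         # ('add', 0)
print(split_tensor_name("v/ExponentialMovingAverage:0"))  # ('v/ExponentialMovingAverage', 0)
```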

    5.5 A Best-Practice TensorFlow Sample Program

    To make the program more extensible, avoid redundant code, and improve programming efficiency, we separate the different functional modules; this section also abstracts the forward-propagation step into a standalone library function.

    mnist_inference

    import tensorflow as tf
    #1. Define the parameters of the network structure
    INPUT_NODE = 784
    OUTPUT_NODE = 10
    LAYER1_NODE = 500
    #2. Obtain variables through the tf.get_variable function
    def get_weight_variable(shape, regularizer):
        weights = tf.get_variable("weights", shape, initializer=tf.truncated_normal_initializer(stddev=0.1))
        if regularizer is not None: tf.add_to_collection('losses', regularizer(weights))
        return weights
    #3. Define the forward-propagation process of the network
    def inference(input_tensor, regularizer):
        with tf.variable_scope('layer1'):
    
            weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer)
            biases = tf.get_variable("biases", [LAYER1_NODE], initializer=tf.constant_initializer(0.0))
            layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases)
    
    
        with tf.variable_scope('layer2'):
            weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer)
            biases = tf.get_variable("biases", [OUTPUT_NODE], initializer=tf.constant_initializer(0.0))
            layer2 = tf.matmul(layer1, weights) + biases
    
        return layer2

    mnist_train

    import os
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    import mnist_inference
    #1. Define the parameters of the network structure
    BATCH_SIZE = 100 
    LEARNING_RATE_BASE = 0.8
    LEARNING_RATE_DECAY = 0.99
    REGULARIZATION_RATE = 0.0001
    TRAINING_STEPS = 30000
    MOVING_AVERAGE_DECAY = 0.99 
    MODEL_SAVE_PATH = "MNIST_model/"
    MODEL_NAME = "mnist_model"
    #2. Define the training process
    def train(mnist):
        # Define the input and output placeholders.
        x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
        y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')
    
        regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
        y = mnist_inference.inference(x, regularizer)
        global_step = tf.Variable(0, trainable=False)
        
        # Define the loss function, learning rate, moving-average op, and training step.
        variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
        variables_averages_op = variable_averages.apply(tf.trainable_variables())
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
        cross_entropy_mean = tf.reduce_mean(cross_entropy)
        loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))
        learning_rate = tf.train.exponential_decay(
            LEARNING_RATE_BASE,
            global_step,
            mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY,
            staircase=True)
        train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
        with tf.control_dependencies([train_step, variables_averages_op]):
            train_op = tf.no_op(name='train')
            
        # Initialize the TensorFlow persistence class.
        saver = tf.train.Saver()
        with tf.Session() as sess:
            tf.global_variables_initializer().run()
    
            for i in range(TRAINING_STEPS):
                xs, ys = mnist.train.next_batch(BATCH_SIZE)
                _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
                if i % 1000 == 0:
                    print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
                    saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
    #3. Main entry point
    def main(argv=None):
        mnist = input_data.read_data_sets("../../../datasets/MNIST_data", one_hot=True)
        train(mnist)
    
    if __name__ == '__main__':
        main()
    -------------------------------------------------------------------------
    Extracting ../../../datasets/MNIST_data/train-images-idx3-ubyte.gz
    Extracting ../../../datasets/MNIST_data/train-labels-idx1-ubyte.gz
    Extracting ../../../datasets/MNIST_data/t10k-images-idx3-ubyte.gz
    Extracting ../../../datasets/MNIST_data/t10k-labels-idx1-ubyte.gz
    After 1 training step(s), loss on training batch is 3.05851.
    After 1001 training step(s), loss on training batch is 0.207949.
    After 2001 training step(s), loss on training batch is 0.214515.
    After 3001 training step(s), loss on training batch is 0.237391.
    After 4001 training step(s), loss on training batch is 0.115064.
    After 5001 training step(s), loss on training batch is 0.103093.
    After 6001 training step(s), loss on training batch is 0.133556.
    ....
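The learning-rate schedule above can be checked by hand: tf.train.exponential_decay computes base_lr * decay_rate ** (global_step / decay_steps), and staircase=True floors the exponent so the rate drops in discrete steps. A plain-Python sketch, assuming the usual 55000 MNIST training images (so decay_steps is 550):

```python
import math

def exponential_decay(base_lr, global_step, decay_steps, decay_rate, staircase=False):
    # learning_rate = base_lr * decay_rate ** (global_step / decay_steps);
    # staircase=True floors the exponent so the rate decays stepwise.
    exponent = global_step / decay_steps
    if staircase:
        exponent = math.floor(exponent)
    return base_lr * decay_rate ** exponent

# Parameters from mnist_train, assuming 55000 MNIST training images:
decay_steps = 55000 // 100   # num_examples / BATCH_SIZE
print(exponential_decay(0.8, 0, decay_steps, 0.99, staircase=True))     # 0.8
print(exponential_decay(0.8, 1100, decay_steps, 0.99, staircase=True))  # 0.8 * 0.99**2
```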

    mnist_eval

    import time
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    import mnist_inference
    import mnist_train
    #1. Load the latest model every 10 seconds
    # Interval between reloads.
    EVAL_INTERVAL_SECS = 10
    
    def evaluate(mnist):
        with tf.Graph().as_default() as g:
            x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
            y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')
            validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
    
            y = mnist_inference.inference(x, None)
            correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
            accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    
            variable_averages = tf.train.ExponentialMovingAverage(mnist_train.MOVING_AVERAGE_DECAY)
            variables_to_restore = variable_averages.variables_to_restore()
            saver = tf.train.Saver(variables_to_restore)
    
            while True:
                with tf.Session() as sess:
                    ckpt = tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH)
                    if ckpt and ckpt.model_checkpoint_path:
                        saver.restore(sess, ckpt.model_checkpoint_path)
                        global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                        accuracy_score = sess.run(accuracy, feed_dict=validate_feed)
                        print("After %s training step(s), validation accuracy = %g" % (global_step, accuracy_score))
                    else:
                        print('No checkpoint file found')
                        return
                time.sleep(EVAL_INTERVAL_SECS)
    #2. Main entry point
    def main(argv=None):
        mnist = input_data.read_data_sets("../../../datasets/MNIST_data", one_hot=True)
        evaluate(mnist)
    
    if __name__ == '__main__':
        main()
    ------------------------------------------------------------------------
    Extracting ../../../datasets/MNIST_data/train-images-idx3-ubyte.gz
    Extracting ../../../datasets/MNIST_data/train-labels-idx1-ubyte.gz
    Extracting ../../../datasets/MNIST_data/t10k-images-idx3-ubyte.gz
    Extracting ../../../datasets/MNIST_data/t10k-labels-idx1-ubyte.gz
    INFO:tensorflow:Restoring parameters from MNIST_model/mnist_model-4001
    After 4001 training step(s), validation accuracy = 0.9826
    INFO:tensorflow:Restoring parameters from MNIST_model/mnist_model-5001
    After 5001 training step(s), validation accuracy = 0.983
    INFO:tensorflow:Restoring parameters from MNIST_model/mnist_model-6001
    After 6001 training step(s), validation accuracy = 0.9832
    INFO:tensorflow:Restoring parameters from MNIST_model/mnist_model-7001
    After 7001 training step(s), validation accuracy = 0.9834...
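The evaluation loop recovers the training step from the checkpoint filename with two split calls; isolated as a plain-Python helper (the name step_from_checkpoint_path is mine, not from the source):

```python
def step_from_checkpoint_path(path):
    # "MNIST_model/mnist_model-4001" -> "4001"
    # First strip the directory, then take the suffix after the last '-'.
    return path.split('/')[-1].split('-')[-1]

print(step_from_checkpoint_path("MNIST_model/mnist_model-4001"))  # 4001
```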
  • Original post: https://www.cnblogs.com/exciting/p/8542859.html