zoukankan      html  css  js  c++  java
  • 【转载】 Tensorflow学习笔记-模型保存与加载

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
    本文链接:https://blog.csdn.net/lovelyaiq/article/details/78646401

    ————————————————

    保存模型时,文件格式有两种,ckpt和pb格式,这两种格式的模型区别是什么呢?首先看一下英文的解释。并且我们的学习中也要养成看英文文档的习惯,其一:老外写的东西通俗易懂,其二,在翻译时,每个人的英文理解不同,原汁原味的道理就没有了。

    The .ckpt is the model given by tensorflow which includes all the 
    weights/parameters in the model.  The .pb file stores the computational 
    graph.  To make tensorflow work we need both the graph and the 
    parameters.  There are two ways to get the graph: 
    (1) use the python program that builds it in the first place (tensorflowNetworkFunctions.py).
    (2) Use a .pb file (which would have to be generated by tensorflowNetworkFunctions.py). 
    .ckpt file is were all the intelligence is.

    使用Tensorflow训练好模型之后,我们需要将训练好的模型保存起来,方便以后的使用,这就是Tensorflow模型的持久化。

    保存

    Tensorflow的模型保存时有几点需要注意:
      1、利用tf.train.write_graph() 默认情况下只导出了网络的定义(没有权重weight)。
      2、利用tf.train.Saver().save() 导出的文件graph_def权重分离的,就像上述英文的描述。
      我们知道,graph_def文件中没有包含网络中的Variable值(通常情况存储了权重),但是却包含了constant值所以如果我们能把Variable转换为constant,即可达到使用一个文件同时存储网络架构与权重的目标
     

    import tensorflow as tf

    v1 = tf.Variable(tf.constant(1,shape = [1]),name='v1')
    v2 = tf.Variable(tf.constant(2,shape = [1]),name='v2')
    result = v1 + v2

    saver = tf.train.Saver()

    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        print(sess.run(v1))
        print(sess.run(v2))
        print(sess.run(result))
        saver.save(sess,'model/model.ckpt')

    模型保存后,在model目录将会有三个文件。在Tensorflow版本0.11之前,这三个文件为:meta、ckpt、checkpopint,它们保存的内容如下:
        model.ckpt.meta保存计算图的结构,即神经网络的结构
        checkpoint保存一个目录下所有的模型文件列表。
        ckpt 保存程序中每一个变量的取值。
     

    在Tensorflow版本0.11之后,有四个文件分别为:meta、.data、.index、checkpoint。其中.data文件为模型中的训练变量。

    模型加载

      模型加载包含两种方式,它们的区分以是否含有计算图上的所有运算。

    包含所有运算

    import tensorflow as tf
    
    v1 = tf.Variable(tf.constant(1,shape = [1]),name='v1')
    v2 = tf.Variable(tf.constant(2,shape = [1]),name='v2')
    result = v1 + v2
    
    saver = tf.train.Saver()
    
    with tf.Session() as sess:
        saver.restore(sess,'model/model.ckpt')
        print(sess.run(v1+v2))

    这种方法加载模型时和保存模型时的代码基本上是一致的,唯一不同的就是没有变量的初始化过程

    模型加载的时候,如果某个变量没有被加载,则系统将会报错。我们可否使用已经定义好的其它变量来加载呢?当然是可以了,因为Tensorflow是支持的,这需要通过字典的形式来完成,将模型中的变量名重名为我们已经定好的其它变量名。

    import tensorflow as tf
    
    x = tf.Variable(tf.constant(1,shape = [1]),name='x')
    y = tf.Variable(tf.constant(2,shape = [1]),name='y')
    result = x + y
    
    # 通过字典将变量重命名
    saver = tf.train.Saver(
        {'v1':x,'v2':y})
    
    with tf.Session() as sess:
       saver.restore(sess,'model/model.ckpt')
       out = tf.get_default_graph().get_tensor_by_name('add:0')
       print(sess.run(out))

    使用变量的滑动平均值的模型保存与加载详见:http://blog.csdn.net/lovelyaiq/article/details/78647850

    不包含所有运算 

    import tensorflow as tf
    
    saver = tf.train.import_meta_graph('model/model.ckpt.meta')
    with tf.Session() as sess:
        saver.restore(sess,'model/model.ckpt')
    
        #获取节点名称
        result = tf.get_default_graph().get_tensor_by_name("add:0")
        print(sess.run(result))

    Saver类

      模型的加载与保存都使用到Saver类,该类的初始化参数为:

      def __init__(self,
                   var_list=None,
                   reshape=False,
                   sharded=False,
                   max_to_keep=5,
                   keep_checkpoint_every_n_hours=10000.0,
                   name=None,
                   restore_sequentially=False,
                   saver_def=None,
                   builder=None,
                   defer_build=False,
                   allow_empty=False,
                   write_version=saver_pb2.SaverDef.V2,
                   pad_step_number=False,
                   save_relative_paths=False,
                   filename=None):

    这里面主要用到的参数: 

        max_to_keep:保存checkpoint文件的最大数量,默认值为5.
        keep_checkpoint_every_n_hours:经过多长时间后,只保留一个checkpoint文件,这是方便验证模型训练多长时间后的性能。默认值为10000.0。
     

    而tf.train.save的参数为:

      def save(self,
               sess,
               save_path,
               global_step=None,
               latest_filename=None,
               meta_graph_suffix="meta",
               write_meta_graph=True,
               write_state=True):

      使用global_stepwrite_meta_graph两个参数可以很好的保存模型。

    saver.save(sess, 'my_test_model',global_step=1000)
    #保存的文件为:
    #my_test_model-1000.index
    #my_test_model-1000.meta
    #my_test_model-1000.data-00000-of-00001
    #checkpoint

    模型在保存的时候,计算图在第一次已经保存过,并且随着训练的进行,计算图是不会改变的,因此以后的保存,就可以使用write_meta_graph=True不保存计算图。

    saver.save(sess, 'my-model', global_step=step,write_meta_graph=False)

    tf.train.Saver()默认保存与加载计算图上所有信息。但有时我们只需要保存或加载部分信息。比如在测试或离线预测时,只需知道如何从神经网络的输入层经过前向传播到输出层即可,而不需要类似于变量的初始化、模型保存等辅助节点的信息。而且有时将变量的取值与计算图分开保存是不方便的,因此就需要借助  convert_variables_to_constants  将计算图上所有的变量及其取值通过常量保存,这样整个计算图将会保存到一个文件中。

    关于 convert_variables_to_constants 的源码定义如下:从解释中看出,当把网络完全转换为single GraphDef file,它可以删除与加载和保存变量相关的很多操作。

    def convert_variables_to_constants(sess, input_graph_def, output_node_names,variable_names_whitelist=None,variable_names_blacklist=None):
      """Replaces all the variables in a graph with constants of the same values.
    
      If you have a trained graph containing Variable ops, it can be convenient to convert them all to Const ops holding the same values. This makes it possible to describe the network fully with a single GraphDef file, and allows the removal of a lot of ops related to loading and saving the variables.
    import tensorflow as tf
    from tensorflow.python.framework import graph_util
    
    v1 = tf.Variable(tf.constant(1,shape = [1]),name='v1')
    v2 = tf.Variable(tf.constant(2,shape = [1]),name='v2')
    result = v1 + v2
    
    init_op = tf.global_variables_initializer()
    
    with tf.Session() as sess:
        sess.run(init_op)
    
        # 导出计算图的GraphDef部分,只需要这一部分就可以完成从输入层到输出层的计算过程。
        graph_def = tf.get_default_graph().as_graph_def()
    
        # print(graph_def)
    
        # 在这里我们只关心"add"节点,因此其它的节点就没有必要导出。
        output_graph_def = graph_util.convert_variables_to_constants(sess,graph_def,['add'])
    
        # 将导出的模型保存到本地
        with tf.gfile.GFile('model/combined_model.pb','wb') as f:
            f.write(output_graph_def.SerializeToString())

      导出模型的恢复:

    import tensorflow as tf
    from tensorflow.python.framework import graph_util
    
    v1 = tf.Variable(tf.constant(1,shape = [1]),name='v1')
    v2 = tf.Variable(tf.constant(2,shape = [1]),name='v2')
    result = v1 + v2
    
    init_op = tf.global_variables_initializer()
    
    with tf.Session() as sess:
        model_filename = 'model/combined_model.pb'
        with tf.gfile.FastGFile(model_filename,'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            # 将graph_def保存的图加入到当前默认的图
        result = tf.import_graph_def(graph_def,return_elements=['add:0'])
        print(sess.run(result))

    上述方法有一个缺点,那就是我们不能自己定义一个网络输入的placeholder接口,这是不是很蛋筒,不要着急,Tensorflow是可以满足我们的需求。

    import tensorflow as tf
    from tensorflow.python.framework import graph_util
    import numpy as np
    
    v1 = tf.Variable(tf.constant(1,shape = [1]),name='v1')
    v2 = tf.Variable(tf.constant(2,shape = [1]),name='v2')
    result = v1 + v2
    
    
    with tf.variable_scope('foo'):
        x = tf.get_variable('x',shape=[1],initializer=tf.constant_initializer(1.0))
        y = tf.get_variable('y', shape=[1], initializer=tf.constant_initializer(2.0))
        # v1 = tf.Variable(tf.constant(1.0,shape=[1]),name='v1')
        # v2 = tf.Variable(tf.constant(2.0,shape=[1]),name='v2')
        input_tensor = tf.placeholder(tf.float32,shape=[1],name='input-x')
        new_tensor = tf.placeholder(tf.float32, shape=[1], name='input-y')
    
        result = tf.add((x+y),input_tensor,name='sum')
    
        data = np.array([15], dtype=np.float32)
    
        init_op = tf.global_variables_initializer()
    
    
        with tf.Session() as sess:
            sess.run(init_op)
            # print(sess.run(result,feed_dict={input_tensor:data}))
            # print(sess.run(result))
            graph_def = tf.get_default_graph().as_graph_def()
            # print(graph_def)
            output_graph_def = graph_util.convert_variables_to_constants(sess,graph_def,['foo/sum'])
            with tf.gfile.GFile('model/combined_model.pb','wb') as f:
                f.write(output_graph_def.SerializeToString())
    
    
        # 模型恢复
        with tf.Session() as sess:
            model_filename = 'model/combined_model.pb'
            with tf.gfile.FastGFile(model_filename,'rb') as f:
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())
    
            # 使用input_map将模型中的placeholder通信映射到重新定义的placeholder。
            result1 = tf.import_graph_def(graph_def ,input_map={'foo/input-x:0':new_tensor},return_elements=['foo/sum:0'],name='')
    
            # [array([ 18.], dtype=float32)]
            print(sess.run(result1,feed_dict={new_tensor:data}))

    这种模型恢复的方法在迁移学习中是常用的方法,至于什么是迁移学习,请参考博客:

  • 相关阅读:
    localX mouseX stageX
    帮陈云庆做的手机报
    另一种换行排列方块的方法
    换行排列(思路源自陈勇源代码)
    网上摘的
    ASP.NET页面间数据传递(转)
    数据库连接字符串大全 之 SQL服务器篇
    保存一个免费的在线的图片转换工具网站,支持BMP,JPG,IOC,PNG和GIF
    关于IE6和IE7以及多个版本IE共存的问题
    如何测试sql语句性能,提高执行效率
  • 原文地址:https://www.cnblogs.com/devilmaycry812839668/p/12483096.html
Copyright © 2011-2022 走看看