zoukankan      html  css  js  c++  java
  • TebsorFlow低阶API(五)—— 保存和恢复

    简介

    tf.train.Saver 类提供了保存和恢复模型的方法。通过 tf.saved_model.simple_save 函数可以轻松地保存适合投入使用的模型。Estimator会自动保存和恢复 model_dir 中的变量。

    保存和恢复变量

    TensorFlow变量是表示由程序操作的共享持久状态的最佳方法。tf.train.Saver 构造函数会针对图中的所有变量或指定列表的变量将 save 和 restore 操作添加到图中。Saver对象提供了运行这些操作的方法,并指定写入或读取检查点文件的路径。

    Saver 会恢复已经在模型中定义的所有变量。如果您在不知道如何构件图的情况下加载模型(例如,您要编写用于加载各类模型的通用程序),那么请阅读本文档后面的保存和恢复模型概述部分。

    TensorFlow将变量保存在二进制检查点文件中,这类文件会将变量名称映射到张量值。

    注意:TensorFlow 模型文件是代码。请注意不可信的代码。详情请参阅安全地使用 TensorFlow

    保存变量

    创建Saver(使用 tf.train.Saver())来管理模型中的所有变量。例如,以下代码展示了如何调用 tf.train.Saver.save 方法以将变量保存到检查点文件中:

     1 # Create some variables.
     2 v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
     3 v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)
     4 
     5 inc_v1 = v1.assign(v1+1)
     6 dec_v2 = v2.assign(v2-1)
     7 
     8 # Add an op to initialize the variables.
     9 init_op = tf.global_variables_initializer()
    10 
    11 # Add ops to save and restore all the variables.
    12 saver = tf.train.Saver()
    13 
    14 # Later, launch the model, initialize the variables, do some work, and save the variables to disk.
    15 with tf.Session() as sess:
    16   sess.run(init_op)
    17   # Do some work with the model.
    18   inc_v1.op.run()
    19   dec_v2.op.run()
    20   # Save the variables to disk.
    21   save_path = saver.save(sess, "/tmp/model.ckpt")
    22   print("Model saved in path: %s" % save_path)

    恢复变量

    tf.train.Saver 对象不仅将变量保存在检查点文件中,还将恢复变量。请注意,当您恢复变量时,您不必事先将其初始化。例如,以下代码段展示了如何调用 tf.train.Saver.restore 方法以从检查点文件中恢复变量:

     1 tf.reset_default_graph()
     2 
     3 # Create some variables.
     4 v1 = tf.get_variable("v1", shape=[3])
     5 v2 = tf.get_variable("v2", shape=[5])
     6 
     7 # Add ops to save and restore all the variables.
     8 saver = tf.train.Saver()
     9 
    10 # Later, launch the model, use the saver to restore variables from disk, and
    11 # do some work with the model.
    12 with tf.Session() as sess:
    13   # Restore variables from disk.
    14   saver.restore(sess, "/tmp/model.ckpt")
    15   print("Model restored.")
    16   # Check the values of the variables
    17   print("v1 : %s" % v1.eval())
    18   print("v2 : %s" % v2.eval())

    注意:并没有名为 /tmp/model.ckpt 的实体文件。它是为检查点创建的文件名的前缀。用户仅与前缀(而非检查点实体文件)互动。

    选择要保存和恢复的变量

    如果您没有向 tf.train.Saver()传递任何参数,则Saver会处理图中的所有变量。每个变量都保存在创建变量时所传递的名称下。

    在检查点文件中明确指定变量名称的这种做法有时非常有用。例如,您可能已经使用名为“weights” 的变量训练了一个模型,而您想要将该变量的值恢复到名为“params”的变量中。

    有时候,仅保存和恢复模型使用的变量子集也会很有裨益。例如,您可能已经训练了一个五层的神经网络,现在您训练一个六层的新模型,并重用该五层的现有权重。您可以使用Saver只恢复这前五层的权重。

    您可以通过向 tf.train.Saver() 构造函数传递以下任一内容,轻松指定要保存或加载的名称和变量:

    • 变量列表(将以其本身的名称保存)。
    • Python字典,其中,键是要使用的名称,键值是要管理的变量。

    继续前面所示的保存/恢复示例:

     1 tf.reset_default_graph()
     2 # Create some variables.
     3 v1 = tf.get_variable("v1", [3], initializer = tf.zeros_initializer)
     4 v2 = tf.get_variable("v2", [5], initializer = tf.zeros_initializer)
     5 
     6 # Add ops to save and restore only `v2` using the name "v2"
     7 saver = tf.train.Saver({"v2": v2})
     8 
     9 # Use the saver object normally after that.
    10 with tf.Session() as sess:
    11   # Initialize v1 since the saver will not.
    12   v1.initializer.run()
    13   saver.restore(sess, "/tmp/model.ckpt")
    14 
    15   print("v1 : %s" % v1.eval())
    16   print("v2 : %s" % v2.eval())

    注意:

    • 如果要保存和恢复模型变量的不同子集,您可以根据需要创建任意数量的 Saver对象。同一个变量可以列在多个Saver对象中,变量的值只有在Saver.restore()方法运行时才会更改。
    • 如果您在会话开始时仅恢复一部分模型变量,则必须为其它变量运行初始化操作。
    • 要检查某个检查点中的变量,您可以使用 inspect_checkpoint 库,尤其是 print_tensors_in_checkpoint_file 函数。默认情况下,Saver会针对每个变量使用 tf.Variable.name 属性的值。但是,当您创建Saver对象时,您可以选择为检查点中的变量选择名称。

    检查某个检查点中的变量

    我们可以使用 inspect_checkpoint 库快速检查某个检查点中的变量。

    继续前面所示的保存/恢复示例:

     1 # import the inspect_checkpoint library
     2 from tensorflow.python.tools import inspect_checkpoint as chkp
     3 
     4 # print all tensors in checkpoint file
     5 chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='', all_tensors=True)
     6 
     7 # tensor_name:  v1
     8 # [ 1.  1.  1.]
     9 # tensor_name:  v2
    10 # [-1. -1. -1. -1. -1.]
    11 
    12 # print only tensor v1 in checkpoint file
    13 chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='v1', all_tensors=False)
    14 
    15 # tensor_name:  v1
    16 # [ 1.  1.  1.]
    17 
    18 # print only tensor v2 in checkpoint file
    19 chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='v2', all_tensors=False)
    20 
    21 # tensor_name:  v2
    22 # [-1. -1. -1. -1. -1.]

    保存和恢复模型

    使用 SavedModel 保存和加载模型-变量、图和图的元数据。SavedModel是一种独立于语言且可恢复的神秘序列化格式,使较高级别的系统和工具可以创建、使用和转换TensorFlow模型。TensorFlow提供了多种与 SaverModel交互的方式,包括 tf.saved_model  API、tf.estimator.Estimator和命令行界面。

    构建和加载SavedModel

    简单保存

    创建SavedModel 的最简单的方法使使用 tf.saved_model.simple_save 函数:

    1 simple_save(session,
    2             export_dir,
    3             inputs={"x": x, "y": y},
    4             outputs={"z": z})

    这样可以配置 SavedModel,使其能够通过 TensorFlow  Serving进行加载,并支持Predict  API。要访问classify API、regress API或者multi-inference API,请使用手动SavedModel builder API或 tf.estimator.Estimator。

    手动构建按SavedModel

    如果您的用例不在 tf.saved_model.simple_save涵盖范围内,请手动 builder API 创建SaverModel。

    tf.saved_model.builder.SavedModelBuilder 类提供了保存多个 MetaGraphDef 的功能。MetaGraph是一种数据流图,并包含相关变量、资源和签名。MetaGraphDef是MetaGraph的协议缓冲区表示法。签名是一组与图有关的输入和输出。

    如果需要将资源保存并写入或复制到磁盘,则可以在首次添加 MetaGraphDef时提供这些资源。如果多个 MetaGraphDef 与同名资源相关联,则只保留首个版本。

    必须使用用户指定的标签对每个添加到 SavedModel 的 MetaGraphDef进行标注。这些标签提供了一种方法来识别要加载和恢复的特定MetaGraphDef,以及共享的变量和资源子集。这些标签一般会标注MetaGraphDef的功能(例如服务或训练),有时也会标注特定的硬件方面的信息(如GPU)。

    例如,以下代码展示了使用MeatGraphDef构建SavedModel的典型方法:

     1 export_dir = ...
     2 ...
     3 builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
     4 with tf.Session(graph=tf.Graph()) as sess:
     5   ...
     6   builder.add_meta_graph_and_variables(sess,
     7                                        [tag_constants.TRAINING],
     8                                        signature_def_map=foo_signatures,
     9                                        assets_collection=foo_assets,
    10                                        strip_default_attrs=True)
    11 ...
    12 # Add a second MetaGraphDef for inference.
    13 with tf.Session(graph=tf.Graph()) as sess:
    14   ...
    15   builder.add_meta_graph([tag_constants.SERVING], strip_default_attrs=True)
    16 ...
    17 builder.save()

    通过 strip_default_attrs=True确保前向兼容性

    只有在操作集合没有变化的情况下,遵循以下指南才能带来向前兼容性。

    SavedModelBuilder类允许用户控制在将元图添加到SaverModel软件时,是否必须从NodeDefs剥离默认属性。SavedModelBuilder.add_meta_graph_add_variable和SavedModelBuilder.add_meta_graph 方法都接受控制此行为的布尔标记strip_default_attrs。

    如果strip_default_attrs=False,则导出的tf.MetaGraphDef 将在其所有的 tf.NodeDef实例中具有设为默认值的属性。这样会破坏前向兼容性并出现一系列事件,详情请参阅兼容性指南

    加载Python版SavedModel

    Python版的SavedModel加载器为SavedModel提供了加载和恢复功能。load指令需要以下信息:

    • 要在其中恢复图定义和变量的会话。
    • 用于标识要加载的MetaGraphDef的标签。
    • SavedModel的位置(目录)

    加载后,作为特定MetaModelDef的一部分提供的变量、资源和签名子集将恢复到提供的会话中。

    1 export_dir = ...
    2 ...
    3 with tf.Session(graph=tf.Graph()) as sess:
    4   tf.saved_model.loader.load(sess, [tag_constants.TRAINING], export_dir)
    5   ...

    加载C++版SavedModel

    C++版SavedModel加载器提供了一个可从某个路径加载SavedModel的API(同时允许SessionOptions和RunOptions)。您必须指定要加载的与图相关联的标签。SavedModel加载后的版本称为SavedModelBundle,其中包含MetaGraphDef和加载时所在的会话。

    const string export_dir = ...
    SavedModelBundle bundle;
    ...
    LoadSavedModel(session_options, run_options, export_dir, {kSavedModelTagTrain},&bundle);

    在TensorFlow Serving中加载和提供SavedMoel

    您可以使用TensorFlow Serving Model Sever二进制文件轻松加载和提供SavedModel。请参阅此处说明,了解如何安装服务器,或根据需要创建服务器。

    一旦您的Model  Sever就绪,请运行以下内容:

    1 tensorflow_model_server --port=port-numbers --model_name=your-model-name --model_base_path=your_model_base_path

    将port和model_name标记设为您所选的值。model_base_path标记应为基本目录,每个版本的模型都放置于以数字命名的子目录中给。如果您的模型只有一个版本,只需如下所示的将其放在子目录中即可:*将模型放入/tmp/model/0001*将model_base_path设为/tmp/model

    将模型的不同版本存储在共用基本目录的子目录中(以数字命名)。例如,假设基本目录是/tmp/model。如果您的模型只有一个版本,请将其存储在/tmp/model/0001中。如果您的模型有两个版本,请将第二个版本存储在/tmp/model/0002中,以此类推。将 --model-bash_path标记设为基本目录(在本例中为/tmp/model)。TensorFlow Model Sever将在该基本目录的最大编号的子目录中提供模型。

     标准常量

    SavedModel为各种用例搭建和加载TensorFlow图提供了灵活性。对于常见的用例,SavedModel的API在Python和C++中提供了一组易于重复使用且在各种工具中共享的常量。

    标准MetaGraphDef标签

    您可以使用标签组唯一标识保存在SavedModel中的MetaGraphDef。常用标签的子集如下:

    标准SignatureDef常量

    SignatureDef是一个协议缓冲区,用于定义图中所支持的计算的签名。常用的输入键、输出键和方法名称定义如下:

    搭配Estimator使用SavedModel

    使用CLI检查并执行SavedModel

    SavedModel目录的结构

    当您以SavedModel格式保存模型时,TensorFlow会自动创建一个由以下子目录和文件组成的SavedModel目录:

    1 assets/
    2 assets.extra/
    3 variables/
    4     variables.data-?????-of-?????
    5     variables.index
    6 saved_model.pb|saved_model.pbtxt

    其中:

    • assets 是包含辅助(外部)文件(如词汇表)的子文件夹。资源被复制到SavedModel的位置,并且可以在加载特定的MetaGraphDef时被读取。
    • assets.extra 是一个子文件夹,其中较高级别的库和用户可以添加自己的资源,这些资源与模型共存,但不会被图加载。此子文件夹不由SavedModel库管理。
    • variables 是包含 tf.train.Saver的输出的子文件夹。
    • saved_model.pbsaved_model.pbtxt 是SavedModel协议缓冲区。它作为MetaGraphDef协议缓冲区的图定义。

    单个SavedModel可以表示多个图。在这种情况下,SavedModel中所有图共享一组检查点(变量)和资源。例如,下图显示了一个包含三个MetaGraphDef的SavedModel,它们都共享共享同一组检查点和资源:

    每组图都与一组特定的标记相关联,可在加载或恢复期间方便您识别。

    参考链接:https://tensorflow.google.cn/guide/saved_model#save_and_restore_variables

  • 相关阅读:
    中国的人生路上是紧跟领导就会有回报
    重游三峡广场有感
    假如你没有我
    关于中小型软件企业技术管理的建议(转)
    街客
    游歌乐山有感
    高成就者的反常思维
    漫谈创业和管理-程序员5大思维障碍 (转)
    QQ情缘
    javascript library
  • 原文地址:https://www.cnblogs.com/lfri/p/10363336.html
Copyright © 2011-2022 走看看