简介
tf.train.Saver 类提供了保存和恢复模型的方法。通过 tf.saved_model.simple_save 函数可以轻松地保存适合投入使用的模型。Estimator会自动保存和恢复 model_dir 中的变量。
保存和恢复变量
TensorFlow变量是表示由程序操作的共享持久状态的最佳方法。tf.train.Saver 构造函数会针对图中的所有变量或指定列表的变量将 save 和 restore 操作添加到图中。Saver对象提供了运行这些操作的方法,并指定写入或读取检查点文件的路径。
Saver 会恢复已经在模型中定义的所有变量。如果您在不知道如何构件图的情况下加载模型(例如,您要编写用于加载各类模型的通用程序),那么请阅读本文档后面的保存和恢复模型概述部分。
TensorFlow将变量保存在二进制检查点文件中,这类文件会将变量名称映射到张量值。
注意:TensorFlow 模型文件是代码。请注意不可信的代码。详情请参阅安全地使用 TensorFlow。
保存变量
创建Saver(使用 tf.train.Saver())来管理模型中的所有变量。例如,以下代码展示了如何调用 tf.train.Saver.save 方法以将变量保存到检查点文件中:
1 # Create some variables. 2 v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer) 3 v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer) 4 5 inc_v1 = v1.assign(v1+1) 6 dec_v2 = v2.assign(v2-1) 7 8 # Add an op to initialize the variables. 9 init_op = tf.global_variables_initializer() 10 11 # Add ops to save and restore all the variables. 12 saver = tf.train.Saver() 13 14 # Later, launch the model, initialize the variables, do some work, and save the variables to disk. 15 with tf.Session() as sess: 16 sess.run(init_op) 17 # Do some work with the model. 18 inc_v1.op.run() 19 dec_v2.op.run() 20 # Save the variables to disk. 21 save_path = saver.save(sess, "/tmp/model.ckpt") 22 print("Model saved in path: %s" % save_path)
恢复变量
tf.train.Saver 对象不仅将变量保存在检查点文件中,还将恢复变量。请注意,当您恢复变量时,您不必事先将其初始化。例如,以下代码段展示了如何调用 tf.train.Saver.restore 方法以从检查点文件中恢复变量:
1 tf.reset_default_graph() 2 3 # Create some variables. 4 v1 = tf.get_variable("v1", shape=[3]) 5 v2 = tf.get_variable("v2", shape=[5]) 6 7 # Add ops to save and restore all the variables. 8 saver = tf.train.Saver() 9 10 # Later, launch the model, use the saver to restore variables from disk, and 11 # do some work with the model. 12 with tf.Session() as sess: 13 # Restore variables from disk. 14 saver.restore(sess, "/tmp/model.ckpt") 15 print("Model restored.") 16 # Check the values of the variables 17 print("v1 : %s" % v1.eval()) 18 print("v2 : %s" % v2.eval())
注意:并没有名为 /tmp/model.ckpt
的实体文件。它是为检查点创建的文件名的前缀。用户仅与前缀(而非检查点实体文件)互动。
选择要保存和恢复的变量
如果您没有向 tf.train.Saver()传递任何参数,则Saver会处理图中的所有变量。每个变量都保存在创建变量时所传递的名称下。
在检查点文件中明确指定变量名称的这种做法有时非常有用。例如,您可能已经使用名为“weights” 的变量训练了一个模型,而您想要将该变量的值恢复到名为“params”的变量中。
有时候,仅保存和恢复模型使用的变量子集也会很有裨益。例如,您可能已经训练了一个五层的神经网络,现在您训练一个六层的新模型,并重用该五层的现有权重。您可以使用Saver只恢复这前五层的权重。
您可以通过向 tf.train.Saver() 构造函数传递以下任一内容,轻松指定要保存或加载的名称和变量:
- 变量列表(将以其本身的名称保存)。
- Python字典,其中,键是要使用的名称,键值是要管理的变量。
继续前面所示的保存/恢复示例:
1 tf.reset_default_graph() 2 # Create some variables. 3 v1 = tf.get_variable("v1", [3], initializer = tf.zeros_initializer) 4 v2 = tf.get_variable("v2", [5], initializer = tf.zeros_initializer) 5 6 # Add ops to save and restore only `v2` using the name "v2" 7 saver = tf.train.Saver({"v2": v2}) 8 9 # Use the saver object normally after that. 10 with tf.Session() as sess: 11 # Initialize v1 since the saver will not. 12 v1.initializer.run() 13 saver.restore(sess, "/tmp/model.ckpt") 14 15 print("v1 : %s" % v1.eval()) 16 print("v2 : %s" % v2.eval())
注意:
- 如果要保存和恢复模型变量的不同子集,您可以根据需要创建任意数量的 Saver对象。同一个变量可以列在多个Saver对象中,变量的值只有在Saver.restore()方法运行时才会更改。
- 如果您在会话开始时仅恢复一部分模型变量,则必须为其它变量运行初始化操作。
- 要检查某个检查点中的变量,您可以使用 inspect_checkpoint 库,尤其是 print_tensors_in_checkpoint_file 函数。默认情况下,Saver会针对每个变量使用 tf.Variable.name 属性的值。但是,当您创建Saver对象时,您可以选择为检查点中的变量选择名称。
检查某个检查点中的变量
我们可以使用 inspect_checkpoint 库快速检查某个检查点中的变量。
继续前面所示的保存/恢复示例:
1 # import the inspect_checkpoint library 2 from tensorflow.python.tools import inspect_checkpoint as chkp 3 4 # print all tensors in checkpoint file 5 chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='', all_tensors=True) 6 7 # tensor_name: v1 8 # [ 1. 1. 1.] 9 # tensor_name: v2 10 # [-1. -1. -1. -1. -1.] 11 12 # print only tensor v1 in checkpoint file 13 chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='v1', all_tensors=False) 14 15 # tensor_name: v1 16 # [ 1. 1. 1.] 17 18 # print only tensor v2 in checkpoint file 19 chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='v2', all_tensors=False) 20 21 # tensor_name: v2 22 # [-1. -1. -1. -1. -1.]
保存和恢复模型
使用 SavedModel 保存和加载模型-变量、图和图的元数据。SavedModel是一种独立于语言且可恢复的神秘序列化格式,使较高级别的系统和工具可以创建、使用和转换TensorFlow模型。TensorFlow提供了多种与 SaverModel交互的方式,包括 tf.saved_model API、tf.estimator.Estimator和命令行界面。
构建和加载SavedModel
简单保存
创建SavedModel 的最简单的方法使使用 tf.saved_model.simple_save 函数:
1 simple_save(session, 2 export_dir, 3 inputs={"x": x, "y": y}, 4 outputs={"z": z})
这样可以配置 SavedModel,使其能够通过 TensorFlow Serving进行加载,并支持Predict API。要访问classify API、regress API或者multi-inference API,请使用手动SavedModel builder API或 tf.estimator.Estimator。
手动构建按SavedModel
如果您的用例不在 tf.saved_model.simple_save涵盖范围内,请手动 builder API 创建SaverModel。
tf.saved_model.builder.SavedModelBuilder 类提供了保存多个 MetaGraphDef 的功能。MetaGraph是一种数据流图,并包含相关变量、资源和签名。MetaGraphDef是MetaGraph的协议缓冲区表示法。签名是一组与图有关的输入和输出。
如果需要将资源保存并写入或复制到磁盘,则可以在首次添加 MetaGraphDef时提供这些资源。如果多个 MetaGraphDef 与同名资源相关联,则只保留首个版本。
必须使用用户指定的标签对每个添加到 SavedModel 的 MetaGraphDef进行标注。这些标签提供了一种方法来识别要加载和恢复的特定MetaGraphDef,以及共享的变量和资源子集。这些标签一般会标注MetaGraphDef的功能(例如服务或训练),有时也会标注特定的硬件方面的信息(如GPU)。
例如,以下代码展示了使用MeatGraphDef构建SavedModel的典型方法:
1 export_dir = ... 2 ... 3 builder = tf.saved_model.builder.SavedModelBuilder(export_dir) 4 with tf.Session(graph=tf.Graph()) as sess: 5 ... 6 builder.add_meta_graph_and_variables(sess, 7 [tag_constants.TRAINING], 8 signature_def_map=foo_signatures, 9 assets_collection=foo_assets, 10 strip_default_attrs=True) 11 ... 12 # Add a second MetaGraphDef for inference. 13 with tf.Session(graph=tf.Graph()) as sess: 14 ... 15 builder.add_meta_graph([tag_constants.SERVING], strip_default_attrs=True) 16 ... 17 builder.save()
通过 strip_default_attrs=True确保前向兼容性
只有在操作集合没有变化的情况下,遵循以下指南才能带来向前兼容性。
SavedModelBuilder类允许用户控制在将元图添加到SaverModel软件时,是否必须从NodeDefs剥离默认属性。SavedModelBuilder.add_meta_graph_add_variable和SavedModelBuilder.add_meta_graph 方法都接受控制此行为的布尔标记strip_default_attrs。
如果strip_default_attrs=False,则导出的tf.MetaGraphDef 将在其所有的 tf.NodeDef实例中具有设为默认值的属性。这样会破坏前向兼容性并出现一系列事件,详情请参阅兼容性指南。
加载Python版SavedModel
Python版的SavedModel加载器为SavedModel提供了加载和恢复功能。load指令需要以下信息:
- 要在其中恢复图定义和变量的会话。
- 用于标识要加载的MetaGraphDef的标签。
- SavedModel的位置(目录)
加载后,作为特定MetaModelDef的一部分提供的变量、资源和签名子集将恢复到提供的会话中。
1 export_dir = ... 2 ... 3 with tf.Session(graph=tf.Graph()) as sess: 4 tf.saved_model.loader.load(sess, [tag_constants.TRAINING], export_dir) 5 ...
加载C++版SavedModel
C++版SavedModel加载器提供了一个可从某个路径加载SavedModel的API(同时允许SessionOptions和RunOptions)。您必须指定要加载的与图相关联的标签。SavedModel加载后的版本称为SavedModelBundle,其中包含MetaGraphDef和加载时所在的会话。
const string export_dir = ...
SavedModelBundle bundle;
...
LoadSavedModel(session_options, run_options, export_dir, {kSavedModelTagTrain},&bundle);
在TensorFlow Serving中加载和提供SavedMoel
您可以使用TensorFlow Serving Model Sever二进制文件轻松加载和提供SavedModel。请参阅此处说明,了解如何安装服务器,或根据需要创建服务器。
一旦您的Model Sever就绪,请运行以下内容:
1 tensorflow_model_server --port=port-numbers --model_name=your-model-name --model_base_path=your_model_base_path
将port和model_name标记设为您所选的值。model_base_path标记应为基本目录,每个版本的模型都放置于以数字命名的子目录中给。如果您的模型只有一个版本,只需如下所示的将其放在子目录中即可:*将模型放入/tmp/model/0001*将model_base_path设为/tmp/model
将模型的不同版本存储在共用基本目录的子目录中(以数字命名)。例如,假设基本目录是/tmp/model。如果您的模型只有一个版本,请将其存储在/tmp/model/0001中。如果您的模型有两个版本,请将第二个版本存储在/tmp/model/0002中,以此类推。将 --model-bash_path标记设为基本目录(在本例中为/tmp/model)。TensorFlow Model Sever将在该基本目录的最大编号的子目录中提供模型。
标准常量
SavedModel为各种用例搭建和加载TensorFlow图提供了灵活性。对于常见的用例,SavedModel的API在Python和C++中提供了一组易于重复使用且在各种工具中共享的常量。
标准MetaGraphDef标签
您可以使用标签组唯一标识保存在SavedModel中的MetaGraphDef。常用标签的子集如下:
标准SignatureDef常量
SignatureDef是一个协议缓冲区,用于定义图中所支持的计算的签名。常用的输入键、输出键和方法名称定义如下:
搭配Estimator使用SavedModel
使用CLI检查并执行SavedModel
SavedModel目录的结构
当您以SavedModel格式保存模型时,TensorFlow会自动创建一个由以下子目录和文件组成的SavedModel目录:
1 assets/ 2 assets.extra/ 3 variables/ 4 variables.data-?????-of-????? 5 variables.index 6 saved_model.pb|saved_model.pbtxt
其中:
- assets 是包含辅助(外部)文件(如词汇表)的子文件夹。资源被复制到SavedModel的位置,并且可以在加载特定的MetaGraphDef时被读取。
- assets.extra 是一个子文件夹,其中较高级别的库和用户可以添加自己的资源,这些资源与模型共存,但不会被图加载。此子文件夹不由SavedModel库管理。
- variables 是包含 tf.train.Saver的输出的子文件夹。
- saved_model.pb或saved_model.pbtxt 是SavedModel协议缓冲区。它作为MetaGraphDef协议缓冲区的图定义。
单个SavedModel可以表示多个图。在这种情况下,SavedModel中所有图共享一组检查点(变量)和资源。例如,下图显示了一个包含三个MetaGraphDef的SavedModel,它们都共享共享同一组检查点和资源:
每组图都与一组特定的标记相关联,可在加载或恢复期间方便您识别。
参考链接:https://tensorflow.google.cn/guide/saved_model#save_and_restore_variables