zoukankan      html  css  js  c++  java
  • 导出pb模型之后测试的python代码

    链接:https://blog.csdn.net/thriving_fcl/article/details/75213361

    saved_model模块主要用于TensorFlow Serving。TF Serving是一个将训练好的模型部署至生产环境的系统,主要的优点在于可以保持Server端与API不变的情况下,部署新的算法或进行试验,同时还有很高的性能。

    保持Server端与API不变有什么好处呢?有很多好处,我只从我体会的一个方面举例子说明一下,比如我们需要部署一个文本分类模型,那么输入和输出是可以确定的,输入文本,输出各类别的概率或类别标签。为了得到较好的效果,我们可能想尝试很多不同的模型,CNN,RNN,RCNN等,这些模型训练好保存下来以后,在inference阶段需要重新载入这些模型,我们希望的是inference的代码有一份就好,也就是使用新模型的时候不需要针对新模型来修改inference的代码。这应该如何实现呢?

    在TensorFlow 模型保存/载入的两种方法中总结过。
    1. 仅用Saver来保存/载入变量。这个方法显然不行,仅保存变量就必须在inference的时候重新定义Graph(定义模型),这样不同的模型代码肯定要修改。即使同一种模型,参数变化了,也需要在代码中有所体现,至少需要一个配置文件来同步,这样就很繁琐了。
    2. 使用tf.train.import_meta_graph导入graph信息并创建Saver, 再使用Saver restore变量。相比第一种,不需要重新定义模型,但是为了从graph中找到输入输出的tensor,还是得用graph.get_tensor_by_name()来获取,也就是还需要知道在定义模型阶段所赋予这些tensor的名字。如果创建各模型的代码都是同一个人完成的,还相对好控制,强制这些输入输出的命名都一致即可。如果是不同的开发者,要在创建模型阶段就强制tensor的命名一致就比较困难了。这样就不得不再维护一个配置文件,将需要获取的tensor名称写入,然后从配置文件中读取该参数。

    经过上面的分析发现,要实现inference的代码统一,使用原来的方法也是可以的,只不过TensorFlow官方提供了更好的方法,并且这个方法不仅仅是解决这个问题,所以还是得学习使用saved_model这个模块。
    saved_model 保存/载入模型

    先列出会用到的API

    class tf.saved_model.builder.SavedModelBuilder

    # 初始化方法
    __init__(export_dir)

    # 导入graph与变量信息
    add_meta_graph_and_variables(
        sess,
        tags,
        signature_def_map=None,
        assets_collection=None,
        legacy_init_op=None,
        clear_devices=False,
        main_op=None
    )

    # 载入保存好的模型
    tf.saved_model.loader.load(
        sess,
        tags,
        export_dir,
        **saver_kwargs
    )

        1
        2
        3
        4
        5
        6
        7
        8
        9
        10
        11
        12
        13
        14
        15
        16
        17
        18
        19
        20
        21
        22
        23

    (1) 最简单的场景,只是保存/载入模型
    保存

    要保存一个已经训练好的模型,使用下面三行代码就可以了。

    builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)
    builder.add_meta_graph_and_variables(sess, ['tag_string'])
    builder.save()

        1
        2
        3

    首先构造SavedModelBuilder对象,初始化方法只需要传入用于保存模型的目录名,目录不用预先创建。

    add_meta_graph_and_variables方法导入graph的信息以及变量,这个方法假设变量都已经初始化好了,对于每个SavedModelBuilder这个方法一定要执行一次用于导入第一个meta graph。

    第一个参数传入当前的session,包含了graph的结构与所有变量。

    第二个参数是给当前需要保存的meta graph一个标签,标签名可以自定义,在之后载入模型的时候,需要根据这个标签名去查找对应的MetaGraphDef,找不到就会报如RuntimeError: MetaGraphDef associated with tags 'foo' could not be found in SavedModel这样的错。标签也可以选用系统定义好的参数,如tf.saved_model.tag_constants.SERVING与tf.saved_model.tag_constants.TRAINING。

    save方法就是将模型序列化到指定目录底下。

    保存好以后到saved_model_dir目录下,会有一个saved_model.pb文件以及variables文件夹。顾名思义,variables保存所有变量,saved_model.pb用于保存模型结构等信息。
    载入

    使用tf.saved_model.loader.load方法就可以载入模型。如

    meta_graph_def = tf.saved_model.loader.load(sess, ['tag_string'], saved_model_dir)

        1

    第一个参数就是当前的session,第二个参数是在保存的时候定义的meta graph的标签,标签一致才能找到对应的meta graph。第三个参数就是模型保存的目录。

    load完以后,也是从sess对应的graph中获取需要的tensor来inference。如

    x = sess.graph.get_tensor_by_name('input_x:0')
    y = sess.graph.get_tensor_by_name('predict_y:0')

    # 实际的待inference的样本
    _x = ...
    sess.run(y, feed_dict={x: _x})

        1
        2
        3
        4
        5
        6

    这样和之前的第二种方法一样,也是要知道tensor的name。那么如何可以在不知道tensor name的情况下使用呢? 那就需要给add_meta_graph_and_variables方法传入第三个参数,signature_def_map。
    (2) 使用SignatureDef

    关于SignatureDef我的理解是,它定义了一些协议,对我们所需的信息进行封装,我们根据这套协议来获取信息,从而实现创建与使用模型的解耦。SignatureDef的结构以及相关详细的文档在:https://github.com/tensorflow/serving/blob/master/tensorflow_serving/g3doc/signature_defs.md

    相关API

    # 构建signature
    tf.saved_model.signature_def_utils.build_signature_def(
        inputs=None,
        outputs=None,
        method_name=None
    )

    # 构建tensor info
    tf.saved_model.utils.build_tensor_info(tensor)

        1
        2
        3
        4
        5
        6
        7
        8
        9

    SignatureDef,将输入输出tensor的信息都进行了封装,并且给他们一个自定义的别名,所以在构建模型的阶段,可以随便给tensor命名,只要在保存训练好的模型的时候,在SignatureDef中给出统一的别名即可。

    TensorFlow的关于这部分的例子中用到了不少signature_constants,这些constants的用处主要是提供了一个方便统一的命名。在我们自己理解SignatureDef的作用的时候,可以先不用管这些,遇到需要命名的时候,想怎么写怎么写。
    保存

    假设定义模型输入的别名为“input_x”,输出的别名为“output” ,使用SignatureDef的代码如下

    builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)
    # x 为输入tensor, keep_prob为dropout的prob tensor
    inputs = {'input_x': tf.saved_model.utils.build_tensor_info(x),
                'keep_prob': tf.saved_model.utils.build_tensor_info(keep_prob)}

    # y 为最终需要的输出结果tensor
    outputs = {'output' : tf.saved_model.utils.build_tensor_info(y)}

    signature = tf.saved_model.signature_def_utils.build_signature_def(inputs, outputs, 'test_sig_name')

    builder.add_meta_graph_and_variables(sess, ['test_saved_model'], {'test_signature':signature})
    builder.save()

        1
        2
        3
        4
        5
        6
        7
        8
        9
        10
        11
        12

    上述inputs增加一个keep_prob是为了说明inputs可以有多个, build_tensor_info方法将tensor相关的信息序列化为TensorInfo protocol buffer。

    inputs,outputs都是dict,key是我们约定的输入输出别名,value就是对具体tensor包装得到的TensorInfo。

    然后使用build_signature_def方法构建SignatureDef,第三个参数method_name暂时先随便给一个。

    创建好的SignatureDef是用在add_meta_graph_and_variables的第三个参数signature_def_map中,但不是直接传入SignatureDef对象。事实上signature_def_map接收的是一个dict,key是我们自己命名的signature名称,value是SignatureDef对象。
    载入

    载入与使用的代码如下


    ## 略去构建sess的代码

    signature_key = 'test_signature'
    input_key = 'input_x'
    output_key = 'output'

    meta_graph_def = tf.saved_model.loader.load(sess, ['test_saved_model'], saved_model_dir)
    # 从meta_graph_def中取出SignatureDef对象
    signature = meta_graph_def.signature_def

    # 从signature中找出具体输入输出的tensor name
    x_tensor_name = signature[signature_key].inputs[input_key].name
    y_tensor_name = signature[signature_key].outputs[output_key].name

    # 获取tensor 并inference
    x = sess.graph.get_tensor_by_name(x_tensor_name)
    y = sess.graph.get_tensor_by_name(y_tensor_name)

    # _x 实际输入待inference的data
    sess.run(y, feed_dict={x:_x})

        1
        2
        3
        4
        5
        6
        7
        8
        9
        10
        11
        12
        13
        14
        15
        16
        17
        18
        19
        20
        21

    从上面两段代码可以知道,我们只需要约定好输入输出的别名,在保存模型的时候使用这些别名创建signature,输入输出tensor的具体名称已经完全隐藏,这就实现创建模型与使用模型的解耦。
    ---------------------
    作者:thriving_fcl
    来源:CSDN
    原文:https://blog.csdn.net/thriving_fcl/article/details/75213361?utm_source=copy
    版权声明:本文为博主原创文章,转载请附上博文链接!

  • 相关阅读:
    Ubuntu 16 安装redis客户端
    crontab 参数详解
    PHP模拟登录发送闪存
    Nginx配置端口访问的网站
    Linux 增加对外开放的端口
    Linux 实用指令之查看端口开启情况
    无敌的极路由
    不同的域名可以指向同一个项目
    MISCONF Redis is configured to save RDB snapshots, but is currently not able to persist on disk. Commands that may modify the data set are disabled. Please check Redis logs for details about the error
    Redis 创建多个端口
  • 原文地址:https://www.cnblogs.com/jerrybaby/p/9804706.html
Copyright © 2011-2022 走看看