zoukankan      html  css  js  c++  java
  • tf.keras.Model使用saved_model,自定义输入输出signature

    环境:tensorflow2.2

    使用tf.keras.Model.save保存saved_model格式时,默认的input和output比较通用,input_1, input2, output_1,output_2

    自定义输入输出名字:

    import tensorflow as tf
    
    sigs = [tf.TensorSpec([None,8], tf.float32, name="a"),
            tf.TensorSpec([None,8], tf.float32, name="b"),
            tf.TensorSpec([None,8], tf.float32, name="c")]
    
    class FullyConnectDnnModel(tf.keras.Model):
      def __init__(self, name):
        super().__init__(name=name)
        self.h1 = tf.keras.layers.Dense(1024, activation='relu')
        self.h2 = tf.keras.layers.Dense(512, activation='relu')
        self.h3 = tf.keras.layers.Dense(256, activation='relu')
        self.h4 = tf.keras.layers.Dense(128, activation='relu')
        self.h5 = tf.keras.layers.Dense(1)
    
      @tf.function(input_signature=[sigs])
      def call(self, emb_layer_list):
        emb_layers = tf.concat(emb_layer_list, axis=1)
        layer1 = self.h1(emb_layers)
        layer2 = self.h2(layer1)
        layer3 = self.h3(layer2)
        layer4 = self.h4(layer3)
        logits = self.h5(layer4)
        predict = tf.nn.sigmoid(logits)
        return {"logits":logits, "predict": predict}
    
    model = FullyConnectDnnModel("test")
    emb_layer_list = []
    for i in range(3):
        emb_layer_list.append(tf.constant(1.0, shape=[4, 8]))
    out = model(emb_layer_list)
    
    model.save("./saved_model")
    

    注意:

    ①call方法的输入是个list,那么input_signature的输入需要是个list[list[tf.TensorSpec]],如果输入是一个tensor,那么input_signature的输入是list[tf.TensorSpec],相当于input_signature必须是了list,list里面是什么需要和call的输入类型对齐.(测试发现tf.2.2版本,keras.Model下的call方法,Input_signature不能传dict,会报错)

    ②call方法可以返回dict,但是官方文档是这样写的.....有误导性:

    后面saved_model的文档又是这样写的.....就很气....:

     

    保存之后执行:

    saved_model_cli show --dir=./saved_model --all
    

     

    参考:

    1.https://www.tensorflow.org/api_docs/python/tf/keras/Model

    2.https://www.tensorflow.org/guide/saved_model?hl=zh-tw#specifying_signatures_during_export

  • 相关阅读:
    oracle之is null和is not null的优化
    oracle命令导出/导入
    Linux环境下后台启动运行jar并设置内存
    阿里云手动安装git客户端
    阿里云安装maven
    python 迭代器
    ThreadPoolExecutor构造器参数详解
    CVE-2020-13957 solr未授权复现
    CVE-2020-9496 apache ofbiz xml-rpc反序列化漏洞分析
    REST API介绍
  • 原文地址:https://www.cnblogs.com/deepllz/p/15408319.html
Copyright © 2011-2022 走看看