zoukankan      html  css  js  c++  java
  • tf.keras.Model使用saved_model,自定义输入输出signature

    环境:tensorflow2.2

    使用tf.keras.Model.save保存saved_model格式时,默认的input和output比较通用,input_1, input2, output_1,output_2

    自定义输入输出名字:

    import tensorflow as tf
    
    sigs = [tf.TensorSpec([None,8], tf.float32, name="a"),
            tf.TensorSpec([None,8], tf.float32, name="b"),
            tf.TensorSpec([None,8], tf.float32, name="c")]
    
    class FullyConnectDnnModel(tf.keras.Model):
      def __init__(self, name):
        super().__init__(name=name)
        self.h1 = tf.keras.layers.Dense(1024, activation='relu')
        self.h2 = tf.keras.layers.Dense(512, activation='relu')
        self.h3 = tf.keras.layers.Dense(256, activation='relu')
        self.h4 = tf.keras.layers.Dense(128, activation='relu')
        self.h5 = tf.keras.layers.Dense(1)
    
      @tf.function(input_signature=[sigs])
      def call(self, emb_layer_list):
        emb_layers = tf.concat(emb_layer_list, axis=1)
        layer1 = self.h1(emb_layers)
        layer2 = self.h2(layer1)
        layer3 = self.h3(layer2)
        layer4 = self.h4(layer3)
        logits = self.h5(layer4)
        predict = tf.nn.sigmoid(logits)
        return {"logits":logits, "predict": predict}
    
    model = FullyConnectDnnModel("test")
    emb_layer_list = []
    for i in range(3):
        emb_layer_list.append(tf.constant(1.0, shape=[4, 8]))
    out = model(emb_layer_list)
    
    model.save("./saved_model")
    

    注意:

    ①call方法的输入是个list,那么input_signature的输入需要是个list[list[tf.TensorSpec]],如果输入是一个tensor,那么input_signature的输入是list[tf.TensorSpec],相当于input_signature必须是了list,list里面是什么需要和call的输入类型对齐.(测试发现tf.2.2版本,keras.Model下的call方法,Input_signature不能传dict,会报错)

    ②call方法可以返回dict,但是官方文档是这样写的.....有误导性:

    后面saved_model的文档又是这样写的.....就很气....:

     

    保存之后执行:

    saved_model_cli show --dir=./saved_model --all
    

     

    参考:

    1.https://www.tensorflow.org/api_docs/python/tf/keras/Model

    2.https://www.tensorflow.org/guide/saved_model?hl=zh-tw#specifying_signatures_during_export

  • 相关阅读:
    Find the Smallest K Elements in an Array
    Count of Smaller Number
    Number of Inversion Couple
    Delete False Elements
    Sort Array
    Tree Diameter
    Segment Tree Implementation
    Java Programming Mock Tests
    zz Morris Traversal方法遍历二叉树(非递归,不用栈,O(1)空间)
    Algorithm about SubArrays & SubStrings
  • 原文地址:https://www.cnblogs.com/deepllz/p/15408319.html
Copyright © 2011-2022 走看看