zoukankan      html  css  js  c++  java
  • TensorFlow:导出带多个输出tensor的pb模型并完成加载

    将checkpoint转换成pb模型,并指定多个输出节点

    def fun():
        """Freeze the latest BERT checkpoint into a single multi-output ``.pb`` file.

        Rebuilds the inference graph with ``create_model``, restores the newest
        checkpoint found under ``best_ckpt``, converts all variables to
        constants and writes the result to ``pb_model/query_encoder.pb``.

        Three nodes are kept as outputs of the frozen graph:
        ``loss/Softmax``, ``bert/pooler/dense/Tanh`` and ``Mean``.
        NOTE(review): ``Mean`` is presumably the averaged-layer sentence
        vector produced inside ``create_model`` -- confirm against that function.

        Raises:
            ValueError: if no checkpoint can be found in the checkpoint directory.
        """
        OUTPUT_GRAPH = 'pb_model/query_encoder.pb'
        output_node = ["loss/Softmax", "bert/pooler/dense/Tanh", "Mean"]
        ckpt_model = r'best_ckpt'
        bert_config_file = r'model/chinese_L-12_H-768_A-12/bert_config.json'
        # Width of the exported placeholders; inputs fed at inference time
        # must be padded/truncated to exactly this length.
        max_seq_length = 10
        gpu_config = tf.ConfigProto()
        gpu_config.gpu_options.allow_growth = True
        graph = tf.get_default_graph()
        # The context manager closes the session (and frees its device memory)
        # even when restoring or freezing fails.
        with graph.as_default(), tf.Session(graph=graph, config=gpu_config) as sess:
            print("going to restore checkpoint")
            input_ids_p = tf.placeholder(tf.int32, [None, max_seq_length], name="input_ids")
            input_mask_p = tf.placeholder(tf.int32, [None, max_seq_length], name="input_mask")
            bert_config = modeling.BertConfig.from_json_file(bert_config_file)
            # Called only for its side effect of building the graph; the
            # output tensors are later selected by node name.
            create_model(bert_config=bert_config, is_training=False, input_ids=input_ids_p, input_mask=input_mask_p,
                         segment_ids=None, labels=None, num_labels=len(label_list), use_one_hot_embeddings=False,
                         fp16=FLAGS.use_fp16)
            # The Saver must be created after the model so it sees every variable.
            saver = tf.train.Saver()
            ckpt_path = tf.train.latest_checkpoint(ckpt_model)
            if ckpt_path is None:
                # Fail with a message that names the directory instead of the
                # generic error raised by Saver.restore(sess, None).
                raise ValueError("no checkpoint found in %r" % ckpt_model)
            saver.restore(sess, ckpt_path)
            frozen_graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node)
            with tf.gfile.GFile(OUTPUT_GRAPH, "wb") as f:
                f.write(frozen_graph_def.SerializeToString())
            print('extract vector pb model saved!')

    推理部分:加载pb模型,并按名称取出输入和多个输出tensor

    class BertEncoder(object):
        """Inference wrapper around a frozen BERT ``.pb`` graph with several outputs.

        After construction the instance exposes:

        * ``input_ids_p`` / ``input_mask_p`` -- the graph's input placeholders
        * ``output_1`` -- tensor ``loss/Softmax:0``
        * ``output_2`` -- tensor ``Mean:0``
        * ``output_3`` -- tensor ``bert/pooler/dense/Tanh:0``
        * ``model_graph`` -- dict holding the parsed ``GraphDef``
          (``'output_graph_def'``) and the session (``'sess'``)
        """

        def __init__(self, OUTPUT_GRAPH):
            """Load the frozen graph and look up its input/output tensors.

            Args:
                OUTPUT_GRAPH: file name of the frozen graph; it is resolved
                    relative to ``CURRENT_DIR/pb_model``.
            """
            # NOTE(review): the export snippet in this file builds its
            # placeholders with max_seq_length = 10; if this class loads that
            # same model, 30 will not match the placeholder shape -- confirm.
            self.max_length = 30
            self.tokenizer = TOKENIZER
            self.out_graph = os.path.join(CURRENT_DIR, "pb_model", OUTPUT_GRAPH)
            self.model_graph = {}

            graph_def = tf.compat.v1.GraphDef()
            with open(self.out_graph, "rb") as f:
                graph_def.ParseFromString(f.read())
            self.model_graph['output_graph_def'] = graph_def

            graph = tf.Graph()
            with graph.as_default():
                # A frozen graph contains no variables, so no initializer run
                # is needed. The returned elements are not used directly;
                # return_elements is passed so the import fails immediately
                # if one of the expected names is missing from the file.
                tf.import_graph_def(
                    graph_def,
                    name="import",
                    return_elements=[INPUT_1, INPUT_2, SOFTMAX_OUTPUT, FIRST_LAST_OUTPUT, CLS_OUTPUT],
                )
            self.model_graph['sess'] = tf.compat.v1.Session(graph=graph)

            # Imported nodes live under the "import/" name scope.
            self.input_ids_p = graph.get_tensor_by_name("import/input_ids:0")
            self.input_mask_p = graph.get_tensor_by_name("import/input_mask:0")
            self.output_1 = graph.get_tensor_by_name("import/loss/Softmax:0")
            self.output_2 = graph.get_tensor_by_name("import/Mean:0")
            self.output_3 = graph.get_tensor_by_name("import/bert/pooler/dense/Tanh:0")

        def close(self):
            """Release the TensorFlow session held by this encoder."""
            self.model_graph['sess'].close()
    时刻记着自己要成为什么样的人!
  • 相关阅读:
    MSSQL2005和Access在SQL的某一种写法上的区别。update的一种写法不一致。
    博客园 记录 了解多一点
    马克斯4.0 采集规则的编写
    谷歌代码托管 GoogleCode中 关于 版本的一个写法
    晒晒名企大公司的工资收入
    Asp.net中DataBinder.Eval用法的总结
    Mastering Debugging in Visual Studio 2010 A Beginner's Guide
    Solution Configuration but not Platform in VS2010 Toolbar
    window.showdialog完全手册,解决模态窗口,传值和返回值问题
    从此不再惧怕URI编码:JavaScript及C# URI编码详解
  • 原文地址:https://www.cnblogs.com/demo-deng/p/14746369.html
Copyright © 2011-2022 走看看