zoukankan      html  css  js  c++  java
  • tensorflow加载多个计算图的冲突解决

    需求:顺序加载多个计算图时,会导致第二个计算图后变量  不可用,在程序初始化中解决该问题(一下代码没有做优化,请读者自行修正)

    class BertEncoder(object):
        """ model
        """
        def __init__(self, OUTPUT_GRAPH, OUT_TENSOR):
            self.max_length = 30
            self.tokenizer = TOKENIZER
            self.out_graph = os.path.join(CURRENT_DIR, "pb_model", OUTPUT_GRAPH)
            self.out_tensor = OUT_TENSOR
            self.model_graph = {}
            graph = tf.Graph()
            with graph.as_default():
                self.model_graph['output_graph_def'] = tf.compat.v1.GraphDef()
                with open(self.out_graph, "rb") as f:
                    self.model_graph['output_graph_def'].ParseFromString(f.read())
                self.model_graph['sess'] = tf.Session(graph=graph)
            with self.model_graph['sess'].as_default():
                with graph.as_default():
                    self.model_graph['sess'].run(tf.compat.v1.global_variables_initializer())
                    tf.import_graph_def(self.model_graph['output_graph_def'], name="")
                    self.input_ids_p = self.model_graph['sess'].graph.get_tensor_by_name("input_ids:0")
                    self.input_mask_p = self.model_graph['sess'].graph.get_tensor_by_name("input_mask:0")
                    self.output_tensor = self.model_graph['sess'].graph.get_tensor_by_name(self.out_tensor)
    
    
        def predict(self, to_predict):
            """pb predict
            """
            sentence = [each.lower() for each in to_predict]
            input_ids, input_mask, = self.convert(sentence)
            feed_dict = {self.input_ids_p: input_ids,
                         self.input_mask_p: input_mask}
            sess = self.model_graph['sess']
            output_emb = sess.run(self.output_tensor, feed_dict)
            return output_emb
    时刻记着自己要成为什么样的人!
  • 相关阅读:
    表达式求值
    火柴排队(归并)
    POJ 3254 压缩状态DP
    ZOJ 3471 压缩状态DP
    Boost IPC Persistence Of Interprocess Mechanisms 例子
    TCO 2014 Round 1A
    Google Code Jam 2014 Qualification 题解
    STL set_difference set_intersection set_union 操作
    b_zj_特征提取(map记录上一个特征运动的次数)
    b_zj_最大连续的相同字符子串的长度(双指针+找突破点)
  • 原文地址:https://www.cnblogs.com/demo-deng/p/14695124.html
Copyright © 2011-2022 走看看