  • Optimizing a TensorFlow graph

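    When we take a trained TensorFlow graph and use it for prediction, it still contains many nodes that were only created for training (for example loss, gradient, and optimizer ops). These nodes are not needed at prediction time, so we should remove them before running inference.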
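    The idea behind this clean-up can be shown without TensorFlow: a node is needed for prediction only if some output depends on it. The following is a minimal sketch in plain Python, assuming a hypothetical dictionary graph that maps each node name to the names of its inputs; the function `prune_for_inference` and all node names are illustrative only. It walks backwards from the output nodes, similar in spirit to the sub-graph extraction TensorFlow performs when freezing a graph, but it is not TensorFlow's implementation.

```python
from collections import deque
from typing import Dict, List, Set


def prune_for_inference(graph: Dict[str, List[str]], outputs: List[str]) -> Set[str]:
    """Return the names of all nodes that the given output nodes depend on.

    `graph` maps a node name to the names of the nodes it reads from.
    Nodes that no output depends on (loss, gradients, optimizer updates, ...)
    are never visited and therefore not returned.
    """
    keep: Set[str] = set()
    queue = deque(outputs)
    while queue:
        name = queue.popleft()
        if name in keep:
            continue  # already visited
        keep.add(name)
        queue.extend(graph.get(name, []))  # visit this node's inputs next
    return keep


# Hypothetical training graph: the inference path plus training-only nodes.
training_graph = {
    "input_ids": [],
    "embeddings": ["input_ids"],
    "pooled": ["embeddings"],
    "final_encodes": ["pooled"],
    "labels": [],
    "loss": ["final_encodes", "labels"],
    "gradients": ["loss"],
    "train_op": ["gradients"],
}

print(sorted(prune_for_inference(training_graph, ["final_encodes"])))
# ['embeddings', 'final_encodes', 'input_ids', 'pooled']
```

    Here `labels`, `loss`, `gradients` and `train_op` are dropped because `final_encodes` does not depend on any of them.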

    Below, we optimize a BERT graph as an example.

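        def optimize_graph(self, checkpoint_file, model_config):
            """Freeze a fine-tuned BERT classifier and strip nodes not needed for inference.

            Returns the path of the serialized, optimized GraphDef.
            """
            import contextlib
            import json
            import os
            import tempfile

            import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.Session, ...)
            from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference

            import modeling  # modeling.py from the google-research/bert repository

            # Build and run the graph on CPU only.
            config = tf.ConfigProto(device_count={'GPU': 0}, allow_soft_placement=True)

            init_checkpoint = checkpoint_file

            with tf.gfile.GFile(model_config, 'r') as f:
                bert_config = modeling.BertConfig.from_dict(json.load(f))

            # MAX_SEQ_LENGTH: module-level constant, the sequence length used during fine-tuning.
            input_ids = tf.placeholder(tf.int32, (None, MAX_SEQ_LENGTH), 'input_ids')
            input_mask = tf.placeholder(tf.int32, (None, MAX_SEQ_LENGTH), 'input_mask')
            input_type_ids = tf.placeholder(tf.int32, (None, MAX_SEQ_LENGTH), 'input_type_ids')

            # contextlib.suppress() with no arguments is a no-op context manager;
            # it is a placeholder where an XLA JIT scope could be used instead.
            jit_scope = contextlib.suppress

            with jit_scope():
                input_tensors = [input_ids, input_mask, input_type_ids]
                model = modeling.BertModel(
                    config=bert_config,
                    is_training=False,  # disables dropout
                    input_ids=input_ids,
                    input_mask=input_mask,
                    token_type_ids=input_type_ids,
                    use_one_hot_embeddings=False)

                tvars = tf.trainable_variables()

                (assignment_map, initialized_variable_names
                 ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)

                # Make the variable initializers load their values from the checkpoint.
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

                # Build the output tensor: the classifier head of the fine-tuned model
                # (variable names as in BERT's run_classifier.py), read from the checkpoint
                # as NumPy arrays so they are embedded in the graph as constants.
                reader = tf.train.NewCheckpointReader(init_checkpoint)
                output_weights = reader.get_tensor('output_weights')  # [num_labels, hidden_size]
                output_bias = reader.get_tensor('output_bias')        # [num_labels]
                output_layers = model.get_pooled_output()             # [batch_size, hidden_size]
                pooled = tf.nn.softmax(tf.nn.bias_add(tf.matmul(output_layers, output_weights, transpose_b=True),
                                                      output_bias))
                pooled = tf.identity(pooled, 'final_encodes')  # stable name for the output node

                output_tensors = [pooled]
                tmp_g = tf.get_default_graph().as_graph_def()

                # Tensor names look like 'input_ids:0'; the graph tools expect node names.
                input_names = [t.name.split(':')[0] for t in input_tensors]
                output_names = [t.name.split(':')[0] for t in output_tensors]

                with tf.Session(config=config) as sess:
                    # Runs the checkpoint-backed initializers set up above.
                    sess.run(tf.global_variables_initializer())
                    # Replace variables with constants and drop nodes the outputs don't need.
                    tmp_g = tf.graph_util.convert_variables_to_constants(sess, tmp_g, output_names)
                    dtypes = [t.dtype for t in input_tensors]
                    # Remove training-only nodes (e.g. CheckNumerics, redundant Identity ops).
                    tmp_g = optimize_for_inference(
                        tmp_g,
                        input_names,
                        output_names,
                        [dtype.as_datatype_enum for dtype in dtypes],
                        False)  # toco_compatible

                    # Write the optimized GraphDef to a uniquely named file under ./optimize.
                    os.makedirs('optimize', exist_ok=True)
                    fd, tmp_file = tempfile.mkstemp(dir='optimize')
                    os.close(fd)
                    with tf.gfile.GFile(tmp_file, 'wb') as f:
                        f.write(tmp_g.SerializeToString())

                    return tmp_file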
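    The classification head built inside `optimize_graph` computes class probabilities as softmax(pooled · W^T + b). The `transpose_b=True` is needed because, assuming the checkpoint was produced by BERT's `run_classifier.py`, `output_weights` is stored with shape `[num_labels, hidden_size]`. The following plain-Python sketch reproduces the same arithmetic for a single example, for illustration only; the function name `classifier_head` and the numbers are hypothetical.

```python
import math
from typing import List


def classifier_head(pooled: List[float],
                    output_weights: List[List[float]],
                    output_bias: List[float]) -> List[float]:
    """softmax(pooled . output_weights^T + output_bias) for one example.

    output_weights has one row per label, i.e. shape [num_labels, hidden_size],
    so each logit is the dot product of `pooled` with one row plus its bias.
    """
    logits = [sum(p * w for p, w in zip(pooled, row)) + b
              for row, b in zip(output_weights, output_bias)]
    # Subtract the largest logit before exponentiating for numerical stability.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


pooled = [1.0, 0.0, -1.0]                             # hidden_size = 3
output_weights = [[0.5, 0.5, 0.5], [1.0, 0.0, -1.0]]  # num_labels = 2
output_bias = [0.0, 0.0]
probs = classifier_head(pooled, output_weights, output_bias)
print([round(p, 4) for p in probs])  # [0.1192, 0.8808]
```

    The logits here are 0.0 and 2.0, so the second label receives most of the probability mass; the outputs always sum to 1.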

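    `optimize_graph` returns the path of the file written with `tf.gfile.GFile`. We can load this file and restore the graph the same way we loaded the original model file; the difference is that this graph is optimized: the variables have been frozen into constants, and only the nodes needed to compute `final_encodes` from the input placeholders remain.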

  • Original article: https://www.cnblogs.com/callyblog/p/10388487.html