zoukankan      html  css  js  c++  java
  • 玩烂bert--集成输出预测分类、特征向量、参数冻结、推理

    功能:

    1)在微调后的模型上继续开发下游任务,同时冻结BERT的12层。方法:加载微调后的模型(而不是Google原始ckpt),然后在custom_optimization.py中只把需要更新的variable(下例中为名称包含"my_variable"的变量)交给optimizer计算梯度:

        # Freeze everything except the variables whose name contains "my_variable":
        # only those are handed to the optimizer, so only they receive gradients.
        # (Passing `tvars` instead would compute gradients for every trainable variable.)
        tvars = tf.trainable_variables()
        update_var_list = [var for var in tvars if "my_variable" in var.name]
        gvs = optimizer.compute_gradients(loss, update_var_list)

    2)顺带输出每个字符的编码向量(768维),向量来源可根据自身需求选取,供下游相似度查询、检索使用,推理时可直接取出。

    下面记录ckpt转pb、以及pb转SavedModel的主要代码:

    def bert_first_last_layer():
        """Freeze the fine-tuned BERT checkpoint into ``pb_model/my_model.pb``.

        Rebuilds the inference graph with ``create_model`` on two int32
        placeholders named ``input_ids`` and ``input_mask`` (shape
        ``[None, max_seq_length]``). It then restores the newest checkpoint
        found in ``new_ckpt`` and converts the variables to constants. Only
        the sub-graph needed for the nodes in ``output_node`` is kept:

        * ``loss/Softmax``
        * ``bert/pooler/dense/Tanh`` -- BERT pooler output
        * ``Mean`` -- presumably the first/last-layer average this function is
          named after; confirm against ``create_model``
        * ``loss/Softmax_1``

        Raises:
            ValueError: if ``new_ckpt`` contains no checkpoint.
        """
        import os  # local import: only used to create the output directory

        output_graph = 'pb_model/my_model.pb'
        output_node = ["loss/Softmax", "bert/pooler/dense/Tanh", "Mean", "loss/Softmax_1"]
        ckpt_model = r'new_ckpt'
        bert_config_file = r'model/chinese_L-12_H-768_A-12/bert_config.json'
        max_seq_length = 350
        confidence_labels_length = 2

        # latest_checkpoint() returns None rather than raising; fail early with a
        # message that names the directory instead of deep inside saver.restore().
        checkpoint = tf.train.latest_checkpoint(ckpt_model)
        if checkpoint is None:
            raise ValueError("no checkpoint found in %r" % ckpt_model)

        gpu_config = tf.ConfigProto()
        gpu_config.gpu_options.allow_growth = True
        # Entering the session makes it (and its graph, the default graph) the
        # default. Leaving the block closes it, so GPU memory is released before
        # any later step in the same process (e.g. pb_2_savedmodel) runs.
        with tf.Session(config=gpu_config) as sess:
            print("going to restore checkpoint")
            input_ids_p = tf.placeholder(tf.int32, [None, max_seq_length], name="input_ids")
            input_mask_p = tf.placeholder(tf.int32, [None, max_seq_length], name="input_mask")
            bert_config = modeling.BertConfig.from_json_file(bert_config_file)
            # Called for its side effect of building the graph; the tensors that are
            # kept are selected by name through `output_node` below.
            create_model(
                bert_config=bert_config, is_training=False, input_ids=input_ids_p, input_mask=input_mask_p,
                segment_ids=None, labels=None, num_labels=confidence_labels_length, use_one_hot_embeddings=False,
                fp16=FLAGS.use_fp16)
            saver = tf.train.Saver()
            saver.restore(sess, checkpoint)
            frozen_graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node)

        # GFile cannot create missing parent directories when writing.
        output_dir = os.path.dirname(output_graph)
        if output_dir:
            tf.gfile.MakeDirs(output_dir)
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(frozen_graph_def.SerializeToString())
        print('extract vector pb model saved!')
    
    
    def pb_2_savedmodel(pb_path="pb_model/my_model.pb", savedmodel_path="merge_savedmodel",
                        output_name=None):
        """Wrap a frozen .pb graph in a SavedModel (version ``1``) for serving.

        Args:
            pb_path: path of the frozen GraphDef, e.g. the file written by
                ``bert_first_last_layer``.
            savedmodel_path: export root. The model is written to
                ``<savedmodel_path>/1``.
            output_name: tensor names exposed, in order, as the signature outputs
                ``output_class``, ``output_cls_vector``, ``output_fl_vector`` and
                ``output_confidence_class``. Defaults to the four nodes kept by
                ``bert_first_last_layer``. Extra names beyond four are ignored.

        The single signature is registered under ``'a_signature'``. Its inputs are
        ``input_ids`` and ``input_mask``.

        Raises:
            ValueError: if fewer than four output names are given.
        """
        if output_name is None:
            output_name = ["loss/Softmax:0", "bert/pooler/dense/Tanh:0", "Mean:0", "loss/Softmax_1:0"]
        # Signature output keys, matched positionally with `output_name`.
        output_keys = ["output_class", "output_cls_vector", "output_fl_vector", "output_confidence_class"]
        if len(output_name) < len(output_keys):
            raise ValueError("output_name needs %d tensor names, got %d"
                             % (len(output_keys), len(output_name)))

        with tf.gfile.GFile(pb_path, 'rb') as f:  # load the frozen graph file
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())

        # Import into a brand-new graph instead of the process-wide default one.
        # Previously `sess.graph.as_default()` was called without `with`, which is a
        # no-op. When this ran after bert_first_last_layer() in the same process, the
        # default graph already held the non-frozen model, so the names looked up
        # below could resolve to its placeholders/variables instead of the frozen ones.
        graph = tf.Graph()
        config = tf.ConfigProto(allow_soft_placement=True)
        # Entering the session also makes `graph` the default graph, and the
        # session is closed when the block exits.
        with tf.Session(graph=graph, config=config) as sess:
            tf.import_graph_def(graph_def, name='')  # import the graph definition

            def _tensor_info(tensor_name):
                """Build a TensorInfo proto for a tensor of the imported graph."""
                return tf.saved_model.utils.build_tensor_info(graph.get_tensor_by_name(tensor_name))

            inputs = {'input_ids': _tensor_info('input_ids:0'),
                      'input_mask': _tensor_info('input_mask:0')}
            outputs = {key: _tensor_info(name) for key, name in zip(output_keys, output_name)}

            export_path = os.path.join(tf.compat.as_bytes(savedmodel_path), tf.compat.as_bytes('1'))
            # Export model with signature
            builder = tf.saved_model.builder.SavedModelBuilder(export_path)
            prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
                inputs=inputs,
                outputs=outputs,
                method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
            builder.add_meta_graph_and_variables(
                sess, [tf.saved_model.tag_constants.SERVING],
                signature_def_map={'a_signature': prediction_signature},
                main_op=tf.tables_initializer())
            builder.save()
        print('savedmodel 保存成功')
    
    if __name__ == '__main__':
        # Step 1: freeze the fine-tuned checkpoint into pb_model/my_model.pb.
        bert_first_last_layer()
        # Step 2: wrap that .pb in a SavedModel under merge_savedmodel/1.
        pb_2_savedmodel()
        # Other helpers used while experimenting (ckpt_2_pb, read_tfrecord,
        # create_20210119, extract_bert_vector) are intentionally not run here.

     3)推理(infer)阶段,如果在同一进程中加载多个图,容易产生冲突,因此获取图时需要特殊处理:为每个模型单独创建tf.Graph和对应的Session。加载图的主干代码如下:

    class ConfidenceModel(object):
        """Inference wrapper around the frozen confidence/classification graph.

        The frozen GraphDef is loaded into a private ``tf.Graph`` with its own
        ``tf.Session``, so that several models can coexist in one process without
        their nodes clashing in the default graph. ``import_graph_def`` is called
        without ``name``, so the imported nodes live under the ``import/`` prefix.

        Attributes set in ``__init__``:
            max_length: maximum sequence length (350).
            tokenizer: the module-level ``TOKENIZER``.
            out_graph: path of the frozen .pb (module-level ``OUTPUT_GRAPH``).
            model_graph: dict holding the parsed GraphDef under
                ``'output_graph_def'`` and the session under ``'sess'``.
            input_ids_p, input_mask_p: the imported input placeholders.
            output_1 .. output_4: the ``loss/Softmax``, ``Mean``,
                ``bert/pooler/dense/Tanh`` and ``loss/Softmax_1`` tensors.
        """

        def __init__(self):
            self.max_length = 350
            self.tokenizer = TOKENIZER
            self.out_graph = OUTPUT_GRAPH

            graph = tf.Graph()
            graph_def = tf.compat.v1.GraphDef()
            with open(self.out_graph, "rb") as pb_file:
                graph_def.ParseFromString(pb_file.read())
            sess = tf.Session(graph=graph)
            self.model_graph = {'output_graph_def': graph_def, 'sess': sess}

            # Make the private session and graph the defaults while importing, so the
            # frozen nodes land in `graph` rather than in the process-wide default graph.
            with sess.as_default(), graph.as_default():
                # NOTE(review): the graph is still empty at this point, so this
                # initializer has nothing to initialise.
                sess.run(tf.compat.v1.global_variables_initializer())
                # NOTE(review): the elements returned here are not used; the tensors are
                # looked up again by hard-coded names below, which must agree with
                # INPUT_1/INPUT_2/*_OUTPUT (import prefix added) -- confirm.
                tf.import_graph_def(
                    graph_def,
                    return_elements=[INPUT_1, INPUT_2, SOFTMAX_OUTPUT, FIRST_LAST_OUTPUT, CLS_OUTPUT,
                                     CONFIDENCE_OUTPUT])
                tensor = graph.get_tensor_by_name
                self.input_ids_p = tensor("import/input_ids:0")
                self.input_mask_p = tensor("import/input_mask:0")
                self.output_1 = tensor("import/loss/Softmax:0")
                self.output_2 = tensor("import/Mean:0")
                self.output_3 = tensor("import/bert/pooler/dense/Tanh:0")
                self.output_4 = tensor("import/loss/Softmax_1:0")
    时刻记着自己要成为什么样的人!
  • 相关阅读:
    [NOIP2020]T2字符串匹配
    【CSGRound2】逐梦者的初心(洛谷11月月赛 II & CSG Round 2 T3)
    【CF1225E Rock Is Push】推岩石
    [HAOI2016]食物链
    求先序排列
    图书管理员
    合并果子
    联合权值
    和为0的4个值
    玩具谜题
  • 原文地址:https://www.cnblogs.com/demo-deng/p/14787234.html
Copyright © 2011-2022 走看看