zoukankan      html  css  js  c++  java
  • Bert tensorflow 版本的线上预测demo

    在模型上线预测时,使用pb格式模型,确定输入tensor和输出tensor,明确对应的节点即可。以下代码是最近做的ner模型的infer部分,大家可以参照修改自己的模型

    import tensorflow as tf
    import os
    import pickle
    from bert_crf import tokenization
    
    model_dir = r'crf_output_bak/'
    output_graph = './pb_model/query_model.pb'
    bert_dir = r'chinese_L-12_H-768_A-12'
    
    # 加载label->id的词典
    with open(os.path.join(model_dir, 'label2id.pkl'), 'rb') as rf:
        label2id = pickle.load(rf)
        id2label = {value: key for key, value in label2id.items()}
    
    with open(os.path.join(model_dir, 'label_list.pkl'), 'rb') as rf:
        label_list = pickle.load(rf)
    num_labels = len(label_list)
    
    tokenizer = tokenization.FullTokenizer(
            vocab_file=os.path.join(bert_dir, 'vocab.txt'), do_lower_case=True)
    
    
    def load_pb_predict():
        """加载pb预测
        """
        text = ['深汕特别合作区']
        # print('input the test sentence:	{}'.format(sentence_all))
        # sentence = str(input())
        sentence = [[s for s in str(each)] for each in text]
        input_ids, input_mask, = convert(sentence)
    
        with tf.Graph().as_default():
            output_graph_def = tf.GraphDef()
            with open(output_graph, "rb") as f:
                output_graph_def.ParseFromString(f.read())
                tf.import_graph_def(output_graph_def, name="")
            res = [each.name for each in output_graph_def.node]
            for each in res:
                print(each)
            with tf.compat.v1.Session() as sess:
                sess.run(tf.compat.v1.global_variables_initializer())
                t1 = time.time()
                input_ids_p = sess.graph.get_tensor_by_name("input_ids:0")
                input_mask_p = sess.graph.get_tensor_by_name("input_mask:0")
                #
                feed_dict = {input_ids_p: input_ids,
                             input_mask_p: input_mask}
                # 定义输出的张量名称
                output_tensor_name = sess.graph.get_tensor_by_name("viterbi/ReverseSequence_1:0")
                out = sess.run(output_tensor_name, feed_dict)
                pred_label_result = convert_id_to_label(out, id2label)
                t2 = time.time()
                print('模型预测吞吐量:{}'.format((t2-t1)/len(input_ids)))
                print(pred_label_result)
    
    
    def convert_id_to_label(pred_ids_result, idx2label):
    
        result = []
        for row in range(len(pred_ids_result)):
            curr_seq = []
            for ids in pred_ids_result[row]:
                if ids == 0:
                    continue
                curr_label = idx2label[ids]
                if curr_label in ['[CLS]', '[SEP]']:
                    continue
                curr_seq.append(curr_label)
            result.append(curr_seq)
        return result
    
    
    def convert(samples):
        input_ids_list = []
        input_mask_list = []
        for line in samples:
            feature = convert_single_example(0, line, label_list, 25)
            input_ids_list.append(feature.input_ids)
            input_mask_list.append(feature.input_mask)
            # input_ids = np.reshape([feature.input_ids],(batch_size, FLAGS.max_seq_length))
            # input_mask = np.reshape([feature.input_mask],(batch_size, FLAGS.max_seq_length))
            # segment_ids = np.reshape([feature.segment_ids],(batch_size, FLAGS.max_seq_length))
            # label_ids =np.reshape([feature.label_ids],(batch_size, FLAGS.max_seq_length))
        return input_ids_list, input_mask_list
    
    
    def convert_single_example(ex_index, example, label_list, max_seq_length):
        """
        将一个样本进行分析,然后将字转化为id, 标签转化为id,然后结构化到InputFeatures对象中
        :param ex_index: index
        :param example: 一个样本
        :param label_list: 标签列表
        :param max_seq_length:
        :param tokenizer:
        :param mode:
        :return:
        """
        label_map = {}
        # 1表示从1开始对label进行index化
        for (i, label) in enumerate(label_list, 1):
            label_map[label] = i
        # 保存label->index 的map
        if not os.path.exists(os.path.join(model_dir, 'label2id.pkl')):
            with open(os.path.join(model_dir, 'label2id.pkl'), 'wb') as w:
                pickle.dump(label_map, w)
    
        tokens = example
        # tokens = .tokenize(example.text)
        # 序列截断
        if len(tokens) >= max_seq_length - 1:
            tokens = tokens[0:(max_seq_length - 2)]  # -2 的原因是因为序列需要加一个句首和句尾标志
        ntokens = []
        segment_ids = []
        label_ids = []
        ntokens.append("[CLS]")  # 句子开始设置CLS 标志
        segment_ids.append(0)
        # append("O") or append("[CLS]") not sure!
        label_ids.append("[CLS]")  # O OR CLS 没有任何影响,不过我觉得O 会减少标签个数,不过拒收和句尾使用不同的标志来标注,使用LCS 也没毛病
        for i, token in enumerate(tokens):
            ntokens.append(token)
            segment_ids.append(0)
            label_ids.append(0)
        ntokens.append("[SEP]")  # 句尾添加[SEP] 标志
        segment_ids.append(0)
        # append("O") or append("[SEP]") not sure!
        label_ids.append("[SEP]")
        input_ids = tokenizer.convert_tokens_to_ids(ntokens)  # 将序列中的字(ntokens)转化为ID形式
        input_mask = [1] * len(input_ids)
    
        # padding, 使用
        while len(input_ids) < max_seq_length:
            input_ids.append(0)
            input_mask.append(0)
            segment_ids.append(0)
            # we don't concerned about it!
            label_ids.append(0)
            ntokens.append("**NULL**")
            # label_mask.append(0)
        # print(len(input_ids))
        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length
        assert len(label_ids) == max_seq_length
        # assert len(label_mask) == max_seq_length
    
        # 结构化为一个类
        feature = InputFeatures(
            input_ids=input_ids,
            input_mask=input_mask,
            segment_ids=segment_ids,
            label_ids=label_ids,
            # label_mask = label_mask
        )
        return feature
    
    
    class InputFeatures(object):
      """A single set of features of data."""
    
      def __init__(self,
                   input_ids,
                   input_mask,
                   segment_ids,
                   label_ids,
                   is_real_example=True):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_ids = label_ids
        self.is_real_example = is_real_example
    
    
    if __name__ == '__main__':
        load_pb_predict()
  • 相关阅读:
    C# 获取程序当前路径
    主线程等待子线程执行二
    ADO.NET Entity Framework Code Fisrt 开篇(一)
    解决Eclipse java was started but returned exit code = 1问题
    windows 下YII框架初试
    hadoop的partitioner
    YII 学习一: YII 初试
    linux 文件大小ll和du不一致问题
    [转载]PyDev for Eclipse 简介
    python 常用代码学习笔记之commands模块
  • 原文地址:https://www.cnblogs.com/demo-deng/p/13625161.html
Copyright © 2011-2022 走看看