  • Online inference demo for the TensorFlow version of BERT

    When serving a model online, use the frozen pb-format graph: identify the input and output tensors and the graph nodes they correspond to, and that is all you need. The code below is the inference part of an NER model I built recently; you can adapt it to your own model.
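
    The script below assumes the frozen graph (query_model.pb) already exists. For reference, here is a minimal sketch of how a checkpoint can be frozen into such a pb file; the checkpoint filename (model.ckpt.meta) is a hypothetical placeholder, and the output node name is taken from the inference code below:

    import tensorflow as tf

    # Sketch: freeze a checkpoint into a single pb file.
    # 'crf_output_bak/model.ckpt.meta' is a hypothetical checkpoint name;
    # use the one your training run actually produced.
    with tf.compat.v1.Session() as sess:
        saver = tf.compat.v1.train.import_meta_graph('crf_output_bak/model.ckpt.meta')
        saver.restore(sess, tf.train.latest_checkpoint('crf_output_bak/'))
        # convert all variables to constants, keeping only the subgraph
        # that feeds the listed output node
        frozen = tf.compat.v1.graph_util.convert_variables_to_constants(
            sess, sess.graph_def, ['viterbi/ReverseSequence_1'])
        with tf.io.gfile.GFile('./pb_model/query_model.pb', 'wb') as f:
            f.write(frozen.SerializeToString())

    With the pb file in place, the full inference script follows.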

    import os
    import pickle
    import time

    import tensorflow as tf
    from bert_crf import tokenization
    
    model_dir = r'crf_output_bak/'
    output_graph = './pb_model/query_model.pb'
    bert_dir = r'chinese_L-12_H-768_A-12'
    
    # load the label->id dictionary
    with open(os.path.join(model_dir, 'label2id.pkl'), 'rb') as rf:
        label2id = pickle.load(rf)
        id2label = {value: key for key, value in label2id.items()}
    
    with open(os.path.join(model_dir, 'label_list.pkl'), 'rb') as rf:
        label_list = pickle.load(rf)
    num_labels = len(label_list)
    
    tokenizer = tokenization.FullTokenizer(
            vocab_file=os.path.join(bert_dir, 'vocab.txt'), do_lower_case=True)
    
    
    def load_pb_predict():
        """Load the frozen pb graph and run NER prediction on a sample sentence.
        """
        text = ['深汕特别合作区']
        # split each sentence into characters: the Chinese NER model
        # operates on single characters
        sentence = [list(str(each)) for each in text]
        input_ids, input_mask = convert(sentence)
    
        with tf.Graph().as_default():
            output_graph_def = tf.compat.v1.GraphDef()
            with open(output_graph, "rb") as f:
                output_graph_def.ParseFromString(f.read())
                tf.import_graph_def(output_graph_def, name="")
            # print all node names; useful for locating the input/output tensors
            for node in output_graph_def.node:
                print(node.name)
            with tf.compat.v1.Session() as sess:
                # no variable initialization needed: a frozen graph
                # contains only constants
                t1 = time.time()
                # look up the input tensors by node name plus ":0"
                input_ids_p = sess.graph.get_tensor_by_name("input_ids:0")
                input_mask_p = sess.graph.get_tensor_by_name("input_mask:0")
                feed_dict = {input_ids_p: input_ids,
                             input_mask_p: input_mask}
                # the output tensor: the viterbi-decoded label-id sequences
                output_tensor_name = sess.graph.get_tensor_by_name("viterbi/ReverseSequence_1:0")
                out = sess.run(output_tensor_name, feed_dict)
                pred_label_result = convert_id_to_label(out, id2label)
                t2 = time.time()
                print('average inference time per sample: {:.4f}s'.format((t2 - t1) / len(input_ids)))
                print(pred_label_result)
    
    
    def convert_id_to_label(pred_ids_result, idx2label):
        """Map predicted label ids back to label strings, skipping
        padding (id 0) and the [CLS]/[SEP] markers."""
        result = []
        for row in range(len(pred_ids_result)):
            curr_seq = []
            for ids in pred_ids_result[row]:
                if ids == 0:
                    continue
                curr_label = idx2label[ids]
                if curr_label in ['[CLS]', '[SEP]']:
                    continue
                curr_seq.append(curr_label)
            result.append(curr_seq)
        return result
    
    
    def convert(samples):
        """Convert character sequences into model input lists."""
        input_ids_list = []
        input_mask_list = []
        for line in samples:
            # max_seq_length is hard-coded to 25 here; set it to the value
            # the model was exported with
            feature = convert_single_example(0, line, label_list, 25)
            input_ids_list.append(feature.input_ids)
            input_mask_list.append(feature.input_mask)
        return input_ids_list, input_mask_list
    
    
    def convert_single_example(ex_index, example, label_list, max_seq_length):
        """
        Process one sample: convert characters to ids and labels to ids,
        then pack everything into an InputFeatures object.
        :param ex_index: index of the example
        :param example: a single sample (list of characters)
        :param label_list: list of labels
        :param max_seq_length: maximum sequence length
        :return: InputFeatures
        """
        label_map = {}
        # start indexing labels from 1; id 0 is reserved for padding
        for (i, label) in enumerate(label_list, 1):
            label_map[label] = i
        # save the label->index map if it has not been saved yet
        if not os.path.exists(os.path.join(model_dir, 'label2id.pkl')):
            with open(os.path.join(model_dir, 'label2id.pkl'), 'wb') as w:
                pickle.dump(label_map, w)
    
        tokens = example  # the input is already split into characters
        # truncate: -2 leaves room for the [CLS] and [SEP] markers
        if len(tokens) > max_seq_length - 2:
            tokens = tokens[0:(max_seq_length - 2)]
        ntokens = []
        segment_ids = []
        label_ids = []
        ntokens.append("[CLS]")  # mark the start of the sentence with [CLS]
        segment_ids.append(0)
        # appending "O" instead of "[CLS]" would also work: "O" keeps the label
        # set smaller, but marking the sentence start and end with distinct
        # tags ([CLS]/[SEP]) is fine too; label_ids are unused at inference
        label_ids.append("[CLS]")
        for i, token in enumerate(tokens):
            ntokens.append(token)
            segment_ids.append(0)
            label_ids.append(0)  # no gold labels at inference time
        ntokens.append("[SEP]")  # mark the end of the sentence with [SEP]
        segment_ids.append(0)
        label_ids.append("[SEP]")
        input_ids = tokenizer.convert_tokens_to_ids(ntokens)  # characters -> vocabulary ids
        input_mask = [1] * len(input_ids)
    
        # pad all sequences up to max_seq_length
        while len(input_ids) < max_seq_length:
            input_ids.append(0)
            input_mask.append(0)
            segment_ids.append(0)
            label_ids.append(0)  # padding labels are ignored
            ntokens.append("**NULL**")
        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length
        assert len(label_ids) == max_seq_length
    
        # pack everything into an InputFeatures object
        feature = InputFeatures(
            input_ids=input_ids,
            input_mask=input_mask,
            segment_ids=segment_ids,
            label_ids=label_ids,
        )
        return feature
    
    
    class InputFeatures(object):
        """A single set of features of data."""

        def __init__(self,
                     input_ids,
                     input_mask,
                     segment_ids,
                     label_ids,
                     is_real_example=True):
            self.input_ids = input_ids
            self.input_mask = input_mask
            self.segment_ids = segment_ids
            self.label_ids = label_ids
            self.is_real_example = is_real_example
    
    
    if __name__ == '__main__':
        load_pb_predict()
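
    When the tensor names inside a frozen graph are unknown, one simple way to find the inputs is to list the Placeholder nodes. The helper below is a hypothetical addition, not part of the original script:

    import tensorflow as tf

    def list_placeholders(pb_path):
        """Return the names of all Placeholder nodes in a frozen graph."""
        graph_def = tf.compat.v1.GraphDef()
        with tf.io.gfile.GFile(pb_path, "rb") as f:
            graph_def.ParseFromString(f.read())
        return [n.name for n in graph_def.node if n.op == "Placeholder"]

    print(list_placeholders('./pb_model/query_model.pb'))

    Output tensors are harder to identify automatically; printing every node name, as load_pb_predict does, and inspecting the nodes near the end of the list (the decode step) is the usual approach.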