  • BERT model + RabbitMQ queue for real-time prediction, so the graph is not reloaded for every prediction

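    1. Create a new class that uses TensorFlow's built-in tf.data.Dataset.from_generator to feed sentences through a generator. Inside the generator, a while loop pulls sentences from RabbitMQ through the channel and hands them on for prediction. Because estimator.predict is called only once and keeps reading from this never-ending generator, the graph and checkpoint are loaded a single time instead of once per sentence.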

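    import time

    import numpy as np
    import tensorflow as tf

    # InputExample, convert_examples_to_features and FLAGS are defined in BERT's run_classifier.py;
    # this class is meant to be added to that file.

    # A class for real-time prediction with a built-in RabbitMQ queue: sentences arrive through
    # the queue, and each prediction is printed as soon as it is made.
    class BertPredictByGen(object):
        def __init__(self, estimator, label_list, tokenizer, channel, queue_name):
            self.estimator = estimator
            self.label_list = label_list
            self.tokenizer = tokenizer
            self.channel = channel
            self.queue_name = queue_name
    
        def input_fn_builder2(self):
            def gen():
                # Endless loop: Estimator.predict keeps pulling from this generator,
                # so the graph is built and the checkpoint restored only once.
                while True:
                    method, properties, qs = self.channel.basic_get(self.queue_name, auto_ack=False)
                    if method is None:
                        # basic_get returns (None, None, None) when the queue is empty;
                        # sleep briefly instead of busy-waiting at 100% CPU.
                        time.sleep(0.1)
                        continue
                    self.channel.basic_ack(delivery_tag=method.delivery_tag)  # acknowledge the message
                    text = str(qs, encoding='UTF-8')
                    # guid is not actually used here and can be anything, but label must be one of
                    # label_list ("其他" means "other"; it is only a placeholder label for prediction)
                    examples = [InputExample(guid=0, text_a=text, text_b=None, label="其他")]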
                    features = convert_examples_to_features(examples, self.label_list, FLAGS.max_seq_length,
                                                            self.tokenizer)
                    all_input_ids = []
                    all_input_mask = []
                    all_segment_ids = []
                    all_label_ids = []
    
                    for feature in features:
                        all_input_ids.append(feature.input_ids)
                        all_input_mask.append(feature.input_mask)
                        all_segment_ids.append(feature.segment_ids)
                        all_label_ids.append(feature.label_id)
    
                    yield {
                           'input_ids': all_input_ids,
                           'input_mask': all_input_mask,
                           'segment_ids': all_segment_ids,
                           'label_ids': all_label_ids,
                           }
    
            def input_fn(params):
                # batch_size = params["batch_size"]
                types = {
                         'input_ids': tf.int32,
                         'input_mask': tf.int32,
                         'segment_ids': tf.int32,
                         'label_ids': tf.int32,
                         }
                shapes = {
                          'input_ids': (None, FLAGS.max_seq_length),
                          'input_mask': (None, FLAGS.max_seq_length),
                          'segment_ids': (None, FLAGS.max_seq_length),
                          'label_ids': (None,),
                          }
                return tf.data.Dataset.from_generator(gen, output_types=types, output_shapes=shapes).prefetch(1)
    
            return input_fn
    
        def predict(self):
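            # With yield_single_examples=False each result holds a whole batch,
            # so take the argmax of each row rather than of the flattened array.
            for result in self.estimator.predict(self.input_fn_builder2(), yield_single_examples=False):
                for probs in result['probabilities']:
                    answer = self.label_list[np.argmax(probs)]  # predicted label
                    print("raw result:", answer)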

    2. Replace the original do_predict branch of the source code with the following:

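    if FLAGS.do_predict:
        # Requires `import pika` at the top of the file and a custom flag FLAGS.project_config_file.
        # The project config file is one you write yourself, in the same JSON format as bert_config_file,
        # to store the RabbitMQ settings (queue_username, queue_password, queue_host,
        # queue_virtual_host, queue_name). BertConfig.from_json_file copies every JSON key onto the
        # returned object, so it works as a generic loader here.
        project_config = modeling.BertConfig.from_json_file(FLAGS.project_config_file)
        credentials = pika.PlainCredentials(username=project_config.queue_username, password=project_config.queue_password)
        connection = pika.BlockingConnection(
            pika.ConnectionParameters(host=project_config.queue_host, virtual_host=project_config.queue_virtual_host,
                                      credentials=credentials))
        channel = connection.channel()  # create a channel

        classifier = BertPredictByGen(estimator=estimator, label_list=label_list, tokenizer=tokenizer, channel=channel,
                                      queue_name=project_config.queue_name)  # instantiate the class
        classifier.predict()  # run prediction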
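The core idea of step 1 is that estimator.predict is called only once and fed by a generator that does not end, so the expensive setup (building the graph, restoring the checkpoint) happens a single time. The pure-Python sketch below illustrates that pattern without TensorFlow or RabbitMQ, under stated assumptions: queue.Queue stands in for the RabbitMQ channel, the hypothetical ToyModel stands in for the Estimator and simply counts its loads, and the STOP sentinel is an assumed addition that lets the stream end cleanly, which the original while True loop does not provide.

```python
import queue
from typing import Iterator, List, Optional, Tuple

STOP = None  # hypothetical shutdown sentinel; the original loop never stops


class ToyModel:
    """Hypothetical stand-in for the Estimator: counts how often the expensive load happens."""

    def __init__(self) -> None:
        self.loads = 0

    def predict(self, inputs: Iterator[str]) -> Iterator[str]:
        # Like Estimator.predict: load once when iteration starts, then consume inputs lazily.
        self.loads += 1
        keywords = {"good": "positive", "bad": "negative"}  # toy "weights"
        for text in inputs:
            label = "other"
            for word, candidate in keywords.items():
                if word in text:
                    label = candidate
                    break
            yield label


def message_stream(q: "queue.Queue[Optional[str]]", poll_seconds: float = 0.05) -> Iterator[str]:
    """Plays the role of gen(): pull one message at a time, waiting when the queue is empty."""
    while True:
        try:
            text = q.get(timeout=poll_seconds)  # blocks for at most poll_seconds
        except queue.Empty:
            continue  # nothing queued yet: poll again instead of exiting
        if text is STOP:
            return  # ends the generator, so predict() finishes cleanly
        yield text


def run(sentences: List[str]) -> Tuple[List[str], int]:
    q: "queue.Queue[Optional[str]]" = queue.Queue()
    for s in sentences:
        q.put(s)
    q.put(STOP)
    model = ToyModel()
    results = list(model.predict(message_stream(q)))
    return results, model.loads


if __name__ == "__main__":
    labels, loads = run(["a good day", "bad news", "plain text"])
    print(labels, loads)  # ['positive', 'negative', 'other'] 1
```

Calling model.predict once per sentence instead would add one load per sentence, which is exactly the reload cost that the generator approach avoids.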
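Step 2 reuses modeling.BertConfig.from_json_file only as a convenient JSON loader. That works because the method copies every key in the file onto the returned object, but it also means a missing or misspelled key only fails later, when the attribute is read for pika. A minimal alternative, assuming the five keys that step 2 reads, is a small dataclass loaded with the standard json module; QueueConfig and load_queue_config are hypothetical names, and the values in the example are placeholders rather than settings from the original post.

```python
import json
from dataclasses import dataclass


@dataclass(frozen=True)
class QueueConfig:
    """Hypothetical holder for the RabbitMQ settings read in step 2."""
    queue_username: str
    queue_password: str
    queue_host: str
    queue_virtual_host: str
    queue_name: str


def load_queue_config(path: str) -> QueueConfig:
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    # QueueConfig(**data) raises TypeError on a missing or unexpected key,
    # so configuration mistakes surface at startup.
    return QueueConfig(**data)


if __name__ == "__main__":
    import os
    import tempfile

    example = {  # placeholder values for illustration only
        "queue_username": "guest",
        "queue_password": "guest",
        "queue_host": "localhost",
        "queue_virtual_host": "/",
        "queue_name": "bert_predict",
    }
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False, encoding="utf-8") as f:
        json.dump(example, f)
    try:
        print(load_queue_config(f.name).queue_name)  # bert_predict
    finally:
        os.remove(f.name)
```

The resulting fields can be passed to pika.PlainCredentials and pika.ConnectionParameters in the same way as the project_config attributes in step 2.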
  • Original article: https://www.cnblogs.com/yiduobaozhiblog1/p/15576399.html