  • Named Entity Recognition with Deep Learning (5): Using the Model

    In this article, you will learn how to build a REST-style named entity extraction API on top of the trained model: given a sentence, the API extracts and returns the person names, locations, organizations, companies, products, and time expressions it contains.

    Core module: entity_extractor.py

    Key functions
    # Load the entity recognition model
    def person_model_init():
       ...

    # Predict the entities in a sentence
    def predict(sentence, labels_config, model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p,
                pred_ids,
                tokenizer,
                sess, max_seq_length):
        ...
    
    Full code
    # -*- coding: utf-8 -*-
    
    """
    基于模型的地址提取
    """
    __author__ = '程序员一一涤生'
    
    import codecs
    import os
    import pickle
    from datetime import datetime
    from pprint import pprint
    import numpy as np
    import tensorflow as tf
    from bert_base.bert import tokenization, modeling
    from bert_base.train.models import create_model, InputFeatures
    from bert_base.train.train_helper import get_args_parser
    
    args = get_args_parser()
    
    def convert(line, model_dir, label_list, tokenizer, batch_size, max_seq_length):
        # Build the model inputs for a single tokenized sentence.
        # Note: the reshapes below assume batch_size == 1.
        feature = convert_single_example(model_dir, 0, line, label_list, max_seq_length, tokenizer, 'p')
        input_ids = np.reshape([feature.input_ids], (batch_size, max_seq_length))
        input_mask = np.reshape([feature.input_mask], (batch_size, max_seq_length))
        segment_ids = np.reshape([feature.segment_ids], (batch_size, max_seq_length))
        label_ids = np.reshape([feature.label_ids], (batch_size, max_seq_length))
        return input_ids, input_mask, segment_ids, label_ids
    
    def predict(sentence, labels_config, model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p,
                pred_ids,
                tokenizer,
                sess, max_seq_length):
        with graph.as_default():
            start = datetime.now()
            # print(id2label)
            sentence = tokenizer.tokenize(sentence)
            # print('your input is:{}'.format(sentence))
            input_ids, input_mask, segment_ids, label_ids = convert(sentence, model_dir, label_list, tokenizer, batch_size,
                                                                    max_seq_length)
    
            feed_dict = {input_ids_p: input_ids,
                         input_mask_p: input_mask}
            # run session get current feed_dict result
            pred_ids_result = sess.run([pred_ids], feed_dict)
            pred_label_result = convert_id_to_label(pred_ids_result, id2label, batch_size)
            # print(pred_ids_result)
            print(pred_label_result)
            # combine the tokens and predicted tags into typed entities
            result = strage_combined(sentence, pred_label_result[0], labels_config)
            print('time used: {} sec'.format((datetime.now() - start).total_seconds()))
        return result, pred_label_result
    
    def convert_id_to_label(pred_ids_result, idx2label, batch_size):
        """
        Convert the id-form prediction back to a sequence of label strings.
        :param pred_ids_result:
        :param idx2label:
        :return:
        """
        result = []
        for row in range(batch_size):
            curr_seq = []
            for ids in pred_ids_result[row][0]:
                if ids == 0:
                    break
                curr_label = idx2label[ids]
                if curr_label in ['[CLS]', '[SEP]']:
                    continue
                curr_seq.append(curr_label)
            result.append(curr_seq)
        return result
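
    # Illustrative decoding (assumed id mapping): with
    #   idx2label = {1: 'O', 2: 'B-PERSON_NAME', 3: 'I-PERSON_NAME', 4: '[CLS]', 5: '[SEP]'}
    # a prediction row [4, 2, 3, 1, 5, 0, 0, ...] decodes to
    #   ['B-PERSON_NAME', 'I-PERSON_NAME', 'O']
    # because id 0 terminates the scan and '[CLS]'/'[SEP]' are skipped.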
    
    def strage_combined(tokens, tags, labels_config):
        """
        Combining strategy: group the predicted entities by type.
        :param tokens: input token sequence
        :param tags: predicted tag sequence
        :param labels_config: entity types to keep
        :return: dict mapping entity type to a list of entity strings
        """
        def get_output(rs, data, type):
            words = []
            for i in data:
                # strip WordPiece markers ('##') from sub-tokens
                words.append(str(i.word).replace("#", ""))
            rs[type] = words
            return rs
        result_builder = Result(labels_config)
        if len(tokens) > len(tags):
            tokens = tokens[:len(tags)]
        labels_dict = result_builder.get_result(tokens, tags)
        rs = {}
        for entity_type, pairs in labels_dict.items():
            rs = get_output(rs, pairs, entity_type)
        return rs
    
    def convert_single_example(model_dir, ex_index, example, label_list, max_seq_length, tokenizer, mode):
        """
        Analyze one example: convert each character to its id and each label to its id,
        then pack everything into an InputFeatures object.
        :param ex_index: index
        :param example: one example (already tokenized)
        :param label_list: list of labels
        :param max_seq_length:
        :param tokenizer:
        :param mode:
        :return:
        """
        label_map = {}
        # start indexing labels from 1 (0 is reserved for padding)
        for (i, label) in enumerate(label_list, 1):
            label_map[label] = i
        # persist the label->index map
        if not os.path.exists(os.path.join(model_dir, 'label2id.pkl')):
            with codecs.open(os.path.join(model_dir, 'label2id.pkl'), 'wb') as w:
                pickle.dump(label_map, w)
        tokens = example
        # tokens = tokenizer.tokenize(example.text)
        # truncate the sequence
        if len(tokens) >= max_seq_length - 1:
            tokens = tokens[0:(max_seq_length - 2)]  # -2 leaves room for the sentence-start and sentence-end markers
        ntokens = []
        segment_ids = []
        label_ids = []
        ntokens.append("[CLS]")  # [CLS] marks the start of the sentence
        segment_ids.append(0)
        # append("O") or append("[CLS]")? Either works; "[CLS]" keeps the sentence start
        # and end marked with dedicated labels, while "O" would mean one label fewer.
        label_ids.append(label_map["[CLS]"])
        for i, token in enumerate(tokens):
            ntokens.append(token)
            segment_ids.append(0)
            label_ids.append(0)
        ntokens.append("[SEP]")  # [SEP] marks the end of the sentence
        segment_ids.append(0)
        # append("O") or append("[SEP]")? Same reasoning as for [CLS] above.
        label_ids.append(label_map["[SEP]"])
        input_ids = tokenizer.convert_tokens_to_ids(ntokens)  # convert the tokens (ntokens) to ids
        input_mask = [1] * len(input_ids)
        # pad everything up to max_seq_length
        while len(input_ids) < max_seq_length:
            input_ids.append(0)
            input_mask.append(0)
            segment_ids.append(0)
            # the padded label ids do not matter
            label_ids.append(0)
            ntokens.append("**NULL**")
            # label_mask.append(0)
        # print(len(input_ids))
        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length
        assert len(label_ids) == max_seq_length
        # assert len(label_mask) == max_seq_length
        # pack the results into an InputFeatures object
        feature = InputFeatures(
            input_ids=input_ids,
            input_mask=input_mask,
            segment_ids=segment_ids,
            label_ids=label_ids,
            # label_mask = label_mask
        )
        return feature
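
    # Illustrative example (assumed values): with max_seq_length = 8 and tokens ['张', '三'],
    # convert_single_example builds:
    #   ntokens     = ['[CLS]', '张', '三', '[SEP]', '**NULL**', '**NULL**', '**NULL**', '**NULL**']
    #   input_mask  = [1, 1, 1, 1, 0, 0, 0, 0]
    #   segment_ids = [0, 0, 0, 0, 0, 0, 0, 0]
    #   input_ids   = the vocab ids of ntokens, zero-padded to length 8
    #   label_ids   = [label_map['[CLS]'], 0, 0, label_map['[SEP]'], 0, 0, 0, 0]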
    
    class Pair(object):
        def __init__(self, word, start, end, type, merge=False):
            self.__word = word
            self.__start = start
            self.__end = end
            self.__merge = merge
            self.__types = type
    
        @property
        def start(self):
            return self.__start
    
        @property
        def end(self):
            return self.__end
    
        @property
        def merge(self):
            return self.__merge
    
        @property
        def word(self):
            return self.__word
    
        @property
        def types(self):
            return self.__types
    
        @word.setter
        def word(self, word):
            self.__word = word
    
        @start.setter
        def start(self, start):
            self.__start = start
    
        @end.setter
        def end(self, end):
            self.__end = end
    
        @merge.setter
        def merge(self, merge):
            self.__merge = merge
    
        @types.setter
        def types(self, type):
            self.__types = type
    
        def __str__(self) -> str:
            line = []
            line.append('entity:{}'.format(self.__word))
            line.append('start:{}'.format(self.__start))
            line.append('end:{}'.format(self.__end))
            line.append('merge:{}'.format(self.__merge))
            line.append('types:{}'.format(self.__types))
            return '\t'.join(line)
    
    class Result(object):
        def __init__(self, labels_config):
            self.others = []
            self.labels_config = labels_config
            self.labels = {}
            for la in self.labels_config:
                self.labels[la] = []
    
        def get_result(self, tokens, tags):
            # build the labelled result first
            self.result_to_json(tokens, tags)
            return self.labels

        def result_to_json(self, string, tags):
            """
            Combine the model's tag sequence with the input sequence to produce the result.
            :param string: input sequence
            :param tags: predicted tags
            :return:
            """
            item = {"entities": []}
            entity_name = ""
            entity_start = 0
            idx = 0
            last_tag = ''
    
            for char, tag in zip(string, tags):
                if tag[0] == "S":
                    self.append(char, idx, idx + 1, tag[2:])
                    item["entities"].append({"word": char, "start": idx, "end": idx + 1, "type": tag[2:]})
                elif tag[0] == "B":
                    if entity_name != '':
                        self.append(entity_name, entity_start, idx, last_tag[2:])
                        item["entities"].append(
                            {"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]})
                        entity_name = ""
                    entity_name += char
                    entity_start = idx
                elif tag[0] == "I":
                    entity_name += char
                elif tag[0] == "O":
                    if entity_name != '':
                        self.append(entity_name, entity_start, idx, last_tag[2:])
                        item["entities"].append(
                            {"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]})
                        entity_name = ""
                else:
                    entity_name = ""
                    entity_start = idx
                idx += 1
                last_tag = tag
            if entity_name != '':
                self.append(entity_name, entity_start, idx, last_tag[2:])
                item["entities"].append({"word": entity_name, "start": entity_start, "end": idx, "type": last_tag[2:]})
            return item
    
        def append(self, word, start, end, tag):
            if tag in self.labels_config:
                self.labels[tag].append(Pair(word, start, end, tag))
            else:
                self.others.append(Pair(word, start, end, tag))
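
    # A worked example of the decoding above (illustrative tags): given
    #   tokens = ['张', '三', '去', '北', '京']
    #   tags   = ['B-PERSON_NAME', 'I-PERSON_NAME', 'O', 'B-LOCATION', 'I-LOCATION']
    # result_to_json closes the PERSON_NAME entity when it reaches the 'O' tag and the
    # LOCATION entity at the end of the sequence, so get_result returns
    #   {'PERSON_NAME': [Pair('张三', 0, 2, 'PERSON_NAME')],
    #    'LOCATION': [Pair('北京', 3, 5, 'LOCATION')], ...}  # other configured types map to []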
    
    def person_model_init():
        return model_init("person")
    
    def model_init(model_name):
        if os.name == 'nt':  # windows path config
            model_dir = 'E:/quickstart/deeplearning/nlp_demo/%s/model' % model_name
            bert_dir = 'E:/quickstart/deeplearning/nlp_demo/bert_model_info/chinese_L-12_H-768_A-12'
        else:  # linux path config
            model_dir = '/home/yjy/project/deeplearning/nlp_demo/%s/model' % model_name
            bert_dir = '/home/yjy/project/deeplearning/nlp_demo/bert_model_info/chinese_L-12_H-768_A-12'
    
        batch_size = 1
        max_seq_length = 500
    
        print('checkpoint path:{}'.format(os.path.join(model_dir, "checkpoint")))
        if not os.path.exists(os.path.join(model_dir, "checkpoint")):
            raise Exception("failed to find the model checkpoint in %s" % model_dir)
    
        # load the label->id dictionary
        with codecs.open(os.path.join(model_dir, 'label2id.pkl'), 'rb') as rf:
            label2id = pickle.load(rf)
            id2label = {value: key for key, value in label2id.items()}
    
        with codecs.open(os.path.join(model_dir, 'label_list.pkl'), 'rb') as rf:
            label_list = pickle.load(rf)
        num_labels = len(label_list) + 1
    
        gpu_config = tf.ConfigProto()
        gpu_config.gpu_options.allow_growth = True
        graph = tf.Graph()
        sess = tf.Session(graph=graph, config=gpu_config)
    
        with graph.as_default():
            print("going to restore checkpoint")
            # sess.run(tf.global_variables_initializer())
            input_ids_p = tf.placeholder(tf.int32, [batch_size, max_seq_length], name="input_ids")
            input_mask_p = tf.placeholder(tf.int32, [batch_size, max_seq_length], name="input_mask")
    
            bert_config = modeling.BertConfig.from_json_file(os.path.join(bert_dir, 'bert_config.json'))
            (total_loss, logits, trans, pred_ids) = create_model(
                bert_config=bert_config, is_training=False, input_ids=input_ids_p, input_mask=input_mask_p,
                segment_ids=None,
                labels=None, num_labels=num_labels, use_one_hot_embeddings=False, dropout_rate=1.0)
    
            saver = tf.train.Saver()
            saver.restore(sess, tf.train.latest_checkpoint(model_dir))
    
        tokenizer = tokenization.FullTokenizer(
            vocab_file=os.path.join(bert_dir, 'vocab.txt'), do_lower_case=args.do_lower_case)
    
        return model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p, pred_ids, tokenizer, sess, max_seq_length
    
    
    if __name__ == "__main__":
        _model_dir, _batch_size, _id2label, _label_list, _graph, _input_ids_p, _input_mask_p, _pred_ids, _tokenizer, _sess, _max_seq_length = person_model_init()
        PERSON_LABELS = ["TIME", "LOCATION", "PERSON_NAME", "ORG_NAME", "COMPANY_NAME", "PRODUCT_NAME"]
        while True:
            print('input the test sentence:')
            _sentence = str(input())
            pred_rs, pred_label_result = predict(_sentence, PERSON_LABELS, _model_dir, _batch_size, _id2label, _label_list,
                                                 _graph,
                                                 _input_ids_p,
                                                 _input_mask_p, _pred_ids, _tokenizer, _sess, _max_seq_length)
            pprint(pred_rs)
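
    Before wiring up Flask, you can sanity-check the model by running entity_extractor.py directly, which drops into the interactive loop above. A hypothetical session might look like this (debug prints omitted; the exact entities depend on your trained model):

    input the test sentence:
    马云于1999年在杭州创办了阿里巴巴。
    {'COMPANY_NAME': ['阿里巴巴'],
     'LOCATION': ['杭州'],
     'ORG_NAME': [],
     'PERSON_NAME': ['马云'],
     'PRODUCT_NAME': [],
     'TIME': ['1999年']}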
    

    Writing the REST-style API

    We will use Python's Flask framework to expose the REST API.

    First, create a new Python project and place the following directories and files under the project root (a sketch of the resulting layout follows this list):

    • The bert_base and bert_model_info directories and their files can be found in the cloud-drive project shared in the previous article, Named Entity Recognition with Deep Learning (4): Model Training;
    • The model directory under person is the named entity recognition model we trained in the previous article, together with a few auxiliary files; it can be found under that project's output directory.
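
    The layout ends up roughly like this (a reconstruction based on the files created below; the chinese_L-12_H-768_A-12 directory is the standard Google Chinese BERT release, and static/ only matters if you add the index.html page referenced in nlp_main.py):

    nlp_demo/
    ├── bert_base/
    ├── bert_model_info/
    │   └── chinese_L-12_H-768_A-12/
    ├── person/
    │   └── model/
    ├── static/
    ├── entity_extractor.py
    ├── flaskr.py
    ├── nlp_config.py
    ├── nlp_main.py
    ├── person_ner_resource.py
    └── requirements.txt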
    Then, create the entry-point file nlp_main.py with the following content:
    # -*- coding: utf-8 -*-
    
    """
    flask 入口
    """
    import os
    import nlp_config as nc
    from flaskr import create_app, loadProjContext
    
    __author__ = '程序员一一涤生'
    
    from flask import jsonify, make_response, redirect
    
    # 加载flask配置信息
    # app = create_app('config.DevelopmentConfig')
    app = create_app(nc.config['default'])
    # 加载项目上下文信息
    loadProjContext()
    
    @app.errorhandler(404)
    def not_found(error):
        return make_response(jsonify({'error': 'Not found'}), 404)
    
    @app.errorhandler(400)
    def not_found(error):
        return make_response(jsonify({'error': '400 Bad Request,参数或参数内容异常'}), 400)
    
    @app.route('/')
    def index_sf():
        # return render_template('index.html')
        return redirect('index.html')
    
    if __name__ == '__main__':
        app.run('localhost', 5006, app, use_reloader=False)
    
    Next, create flaskr.py, the Flask initialization file. It pre-configures logging and CORS at startup, and its loadProjContext function loads the TensorFlow graph, session, and tokenizer once, sharing them across requests through module-level variables in person_ner_resource:
    
    # -*- coding: utf-8 -*-
    """
    flask初始化
    """
    from logging.config import dictConfig
    from flask import Flask
    from flask_cors import CORS
    import person_ner_resource
    from entity_extractor import person_model_init
    from person_ner_resource import person
    
    __author__ = '程序员一一涤生'
    
    def create_app(config_type):
        dictConfig({
            'version': 1,
            'formatters': {'default': {
                'format': '[%(asctime)s] %(name)s %(levelname)s in %(module)s %(lineno)d: %(message)s',
            }},
            'handlers': {'wsgi': {
                'class': 'logging.StreamHandler',
                'stream': 'ext://flask.logging.wsgi_errors_stream',
                'formatter': 'default'
            }},
            'root': {
                'level': 'DEBUG',
                # 'level': 'WARN',
                # 'level': 'INFO',
                'handlers': ['wsgi']
            }
        })
        # load the Flask configuration
        app = Flask(__name__, static_folder='static', static_url_path='')
        # r'/*' matches every URL on this server; origins='*' allows cross-origin
        # requests from any host
        # CORS(app, resources=r'/*', origins=['192.168.1.104'])  # restrict to one host instead
        CORS(app, resources={r"/*": {"origins": '*'}})
        app.config.from_object(config_type)
        app.register_blueprint(person, url_prefix='/person')
        # initialize the application context
        ctx = app.app_context()
        ctx.push()
        return app
    
    def loadProjContext():
        # load the person-name extraction model and stash it in module-level
        # variables so every request can reuse the same graph and session
        model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p, pred_ids, tokenizer, sess, max_seq_length = person_model_init()
        person_ner_resource.model_dir = model_dir
        person_ner_resource.batch_size = batch_size
        person_ner_resource.id2label = id2label
        person_ner_resource.label_list = label_list
        person_ner_resource.graph = graph
        person_ner_resource.input_ids_p = input_ids_p
        person_ner_resource.input_mask_p = input_mask_p
        person_ner_resource.pred_ids = pred_ids
        person_ner_resource.tokenizer = tokenizer
        person_ner_resource.sess = sess
        person_ner_resource.max_seq_length = max_seq_length
    
    
    Then, create the configuration file nlp_config.py, used to switch between production, development, and testing environments:
    # -*- coding: utf-8 -*-
    
    """
    本模块是Flask的配置模块
    """
    import os
    
    __author__ = '程序员一一涤生'
    
    basedir = os.path.abspath(os.path.dirname(__file__))
    
    class BaseConfig:  # 基本配置类
        SECRET_KEY = b'xe4rx04xb5xb2x00xf1xadfxa3xf3Vx03xc5x9fx82$^xa25Oxf0Rxda'
        JSONIFY_MIMETYPE = 'application/json; charset=utf-8'  # 默认JSONIFY_MIMETYPE的配置是不带'; charset=utf-8的'
        JSON_AS_ASCII = False  # 若不关闭,使用JSONIFY返回json时中文会显示为Unicode字符
        ENCODING = 'utf-8'
    
        # 自定义的配置项
        PERSON_LABELS = ["TIME", "LOCATION", "PERSON_NAME", "ORG_NAME", "COMPANY_NAME", "PRODUCT_NAME"]
    
    class DevelopmentConfig(BaseConfig):
        ENV = 'development'
        DEBUG = True
    
    class TestingConfig(BaseConfig):
        TESTING = True
        WTF_CSRF_ENABLED = False
    
    class ProductionConfig(BaseConfig):
        DEBUG = False
    
    config = {
        'testing': TestingConfig,
        'default': DevelopmentConfig
        # 'default': ProductionConfig
    }
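
    To switch environments, point the 'default' key at a different class (for example, swap in the commented-out ProductionConfig line above), or pass a different key when creating the app. A sketch of the latter, not used by nlp_main.py as written:

    app = create_app(nc.config['testing'])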
    
    Next, create the named entity recognition endpoint file person_ner_resource.py. It accepts a POST whose JSON body carries the text in a single field t (2 to 500 characters) and returns the extracted entities grouped by type:
    
    # -*- coding: utf-8 -*-
    
    """
    命名实体识别接口
    """
    from entity_extractor import predict
    
    __author__ = '程序员一一涤生'
    
    from flask import Blueprint, make_response, request, current_app
    from flask import jsonify
    person = Blueprint('person', __name__)
    
    model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p, pred_ids, tokenizer, sess, max_seq_length = (None,) * 11

    @person.route('/extract', methods=['POST'])
    def extract():
        params = request.get_json() or {}  # treat a missing or invalid JSON body as empty
        if 't' not in params or params['t'] is None or len(params['t']) > 500 or len(params['t']) < 2:
            return make_response(jsonify({'error': 'Text length out of range; allowed length: 2-500'}), 400)
        sentence = params['t']
        # make sure the text ends like a sentence
        sentence = sentence + "。" if not sentence.endswith((",", "。", "!", "?")) else sentence
        # extract entities with the model
        pred_rs, pred_label_result = predict(sentence, current_app.config['PERSON_LABELS'], model_dir, batch_size, id2label,
                                             label_list, graph, input_ids_p,
                                             input_mask_p,
                                             pred_ids, tokenizer, sess, max_seq_length)
        print(sentence)
        return jsonify(pred_rs)
    
    if __name__ == '__main__':
        pass
    
    Next, place the requirements.txt file in the project root with the following content:
    absl-py==0.7.0
    astor==0.7.1
    backcall==0.1.0
    backports.weakref==1.0rc1
    bleach==1.5.0
    certifi==2016.2.28
    click==6.7
    colorama==0.4.1
    colorful==0.5.0
    decorator==4.3.2
    defusedxml==0.5.0
    entrypoints==0.3
    Flask==1.0.2
    Flask-Cors==3.0.3
    gast==0.2.2
    grpcio==1.18.0
    h5py==2.9.0
    html5lib==0.9999999
    ipykernel==5.1.0
    ipython==7.2.0
    ipython-genutils==0.2.0
    ipywidgets==7.4.2
    itsdangerous==0.24
    jedi==0.13.2
    Jinja2==2.10
    jsonschema==2.6.0
    jupyter==1.0.0
    jupyter-client==5.2.4
    jupyter-console==6.0.0
    jupyter-core==4.4.0
    Keras-Applications==1.0.6
    Keras-Preprocessing==1.0.5
    Markdown==3.0.1
    MarkupSafe==1.1.0
    mistune==0.8.4
    mock==3.0.5
    nbconvert==5.4.0
    nbformat==4.4.0
    notebook==5.7.4
    numpy==1.16.0
    pandocfilters==1.4.2
    parso==0.3.2
    pickleshare==0.7.5
    prettyprinter==0.17.0
    prometheus-client==0.5.0
    prompt-toolkit==2.0.8
    protobuf==3.6.1
    Pygments==2.3.1
    python-dateutil==2.7.5
    pywinpty==0.5.5
    pyzmq==17.1.2
    qtconsole==4.4.3
    Send2Trash==1.5.0
    six==1.12.0
    tensorboard==1.13.1
    tensorflow==1.13.1
    tensorflow-estimator==1.13.0
    termcolor==1.1.0
    terminado==0.8.1
    testpath==0.4.2
    tornado==5.1.1
    traitlets==4.3.2
    wcwidth==0.1.7
    Werkzeug==0.14.1
    widgetsnbextension==3.4.2
    wincertstore==0.2
    
    Then run the following command to install the packages listed in requirements.txt:
    pip install -i https://pypi.tuna.tsinghua.edu.cn/simple -r requirements.txt
    

    Once the steps above are complete, we can try starting the project.

    Starting the project

    Run the following command to start the Flask project:

    python nlp_main.py
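
    If the model loads successfully, you will first see the checkpoint messages printed by model_init, then Flask's usual startup banner, along these lines (the exact output varies by machine and Flask version):

    checkpoint path:/home/yjy/project/deeplearning/nlp_demo/person/model/checkpoint
    going to restore checkpoint
     * Running on http://localhost:5006/ (Press CTRL+C to quit)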
    

    Calling the API

    This article uses Postman to call the named entity extraction endpoint at:

    http://localhost:5006/person/extract

    Sample call result (screenshot omitted):
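
    For readers without Postman, an equivalent command-line call might look like the sketch below; the field t is the parameter that extract() reads, while the sentence and the returned entities are illustrative and depend on your trained model:

    curl -X POST http://localhost:5006/person/extract \
         -H "Content-Type: application/json" \
         -d '{"t": "马云于1999年在杭州创办了阿里巴巴"}'

    # illustrative response
    {"COMPANY_NAME": ["阿里巴巴"], "LOCATION": ["杭州"], "ORG_NAME": [], "PERSON_NAME": ["马云"], "PRODUCT_NAME": [], "TIME": ["1999年"]}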

    Note that inference on a CPU takes roughly 2 to 3 seconds per call. If the project is deployed on a machine with a GPU that supports deep learning, the API will respond much faster; in that case, remember to install tensorflow-gpu instead of tensorflow.

    That's it for this installment. We have now used deep learning to build a project that can extract person names, locations, organizations, companies, products, and time expressions from natural language. Starting with the next article, we will introduce BERT-CRF, the deep-learning algorithm this project relies on; understanding the algorithm will make it much clearer why the model can accurately extract the entities we want from a sentence.

    OK, that's all for this article. Thanks for reading O(∩_∩)O, bye!

    This blog post comes from the WeChat official account "程序员一一涤生"; you are welcome to follow it o(∩_∩)o
