  • 【NLP】Using BERT

    # Based on https://blog.csdn.net/luoyexuge/article/details/84939755, with small modifications

    Requirements:

      Download the BERT code from GitHub: https://github.com/google-research/bert

      Download Google's pre-trained Chinese model: https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip
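
      After unzipping, the model directory should contain bert_model.ckpt.data-00000-of-00001, bert_model.ckpt.index, bert_model.ckpt.meta, bert_config.json, and vocab.txt; the paths used in the code below point at these files.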

    Usage:

      Using BERT really means using a few checkpoint (ckpt) files. The zip downloaded above is Google's pre-trained BERT; we can continue training from the ckpt files inside it to obtain ckpt files better suited to a specific task.
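
     For the continue-training case, the usual pattern (used by the official run scripts such as run_classifier.py) is to build the task graph first and then map the pre-trained weights onto it. A minimal sketch of that pattern, assuming the graph has already been built and init_checkpoint holds the path to bert_model.ckpt:

    import tensorflow as tf
    from bert import modeling

    # Match variables in the freshly built graph against the pre-trained
    # checkpoint, then initialize them from it before training starts.
    tvars = tf.trainable_variables()
    (assignment_map, _) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)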

     If you use the pre-trained ckpt files (i.e. the BERT model) directly, the following code is all that is needed to define the model and obtain its output:

    from bert import modeling

    # Build BertModel on the input data to get per-token embeddings
    model = modeling.BertModel(
        config=bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=use_one_hot_embeddings
    )
    # Per-token embeddings, shape [batch_size, seq_length, embedding_size]
    embedding = model.get_sequence_output()

    Here bert_config is the previously defined bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file); the inputs are the three tensors input_ids, input_mask, and segment_ids; there are also two settings, is_training (False) and use_one_hot_embeddings (False). There are many more settings like these; only these two are listed here.
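
    Besides the per-token sequence output, BertModel also exposes a sentence-level vector (the transformed [CLS] token), which is often what you want for classification tasks:

    # Sentence-level vector, shape [batch_size, hidden_size]
    pooled = model.get_pooled_output()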

    As for FLAGS, this is TensorFlow's flags module, which acts as run-time configuration. It is set up as follows:

    import os
    import tensorflow as tf
    
    flags = tf.flags
    FLAGS = flags.FLAGS
    
    # Paths to the pre-trained Chinese model and to the project
    bert_path = '/home/xiangbo_wang/xiangbo/NER/chinese_L-12_H-768_A-12/'
    root_path = '/home/xiangbo_wang/xiangbo/NER/BERT-BiLSTM-CRF-NER'
    
    # Set bert_config_file
    flags.DEFINE_string(
        "bert_config_file", os.path.join(bert_path, 'bert_config.json'),
        "The config json file corresponding to the pre-trained BERT model."
    )
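
    Flags defined this way can be overridden on the command line without editing the code; assuming the snippet above is saved as a script (the file name here is hypothetical):

    # python bert_embed.py --bert_config_file=/some/other/bert_config.json
    print(FLAGS.bert_config_file)  # resolved value at run time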

     For details on the three input tensors, see the earlier post https://www.cnblogs.com/rucwxb/p/10277217.html

    input_ids and segment_ids feed the token embedding and segment embedding, respectively.

    The position embedding is generated automatically.

    input_mask marks which input positions are valid. During pre-training a random subset of tokens is masked out for the masked-LM objective; here all real input positions are simply marked as valid (1) and only the padding is masked (0).
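
    For illustration, with maxlen=10 and the text "你好" (the ids below are made up; real ids come from vocab.txt):

    # [CLS] 你 好 [SEP] followed by padding (illustrative ids)
    input_ids   = [101, 872, 1962, 102, 0, 0, 0, 0, 0, 0]  # token ids, padded to maxlen
    input_mask  = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]           # 1 = real token, 0 = padding
    segment_ids = [0] * 10                                  # single sentence, all zeros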

    These three input vectors are obtained as follows:

    # Function that builds the three vectors
    def inputs(vectors, maxlen=10):
        length = len(vectors)
        if length >= maxlen:
            return vectors[0:maxlen], [1]*maxlen, [0]*maxlen
        else:
            input = vectors + [0]*(maxlen-length)
            mask = [1]*length + [0]*(maxlen-length)
            segment = [0]*maxlen
            return input, mask, segment
    
    # Sentence to encode, taken here from a web request;
    # di is the token-to-id dict loaded from vocab.txt (see load_vocab below)
    text = request.args.get('text')
    vectors = [di.get("[CLS]")] + [di.get(i) if i in di else di.get("[UNK]") for i in list(text)] + [di.get("[SEP]")]
    
    # Reshape into 1*maxlen arrays
    input, mask, segment = inputs(vectors)
    input_ids = np.reshape(np.array(input), [1, -1])
    input_mask = np.reshape(np.array(mask), [1, -1])
    segment_ids = np.reshape(np.array(segment), [1, -1])

    Finally, feed these into the model to obtain the final BERT vectors:

    # Placeholders defining the input shapes
    input_ids_p = tf.placeholder(shape=[None, None], dtype=tf.int32, name="input_ids_p")
    input_mask_p = tf.placeholder(shape=[None, None], dtype=tf.int32, name="input_mask_p")
    segment_ids_p = tf.placeholder(shape=[None, None], dtype=tf.int32, name="segment_ids_p")
    
    model = modeling.BertModel(
        config=bert_config,
        is_training=is_training,
        input_ids=input_ids_p,
        input_mask=input_mask_p,
        token_type_ids=segment_ids_p,
        use_one_hot_embeddings=use_one_hot_embeddings
    )
    
    # Load the pre-trained model
    restore_saver = tf.train.Saver()
    restore_saver.restore(sess, init_checkpoint)
    
    # get_sequence_output() is [batch_size, seq_length, embedding_size];
    # squeeze drops the batch dimension of 1
    embedding = tf.squeeze(model.get_sequence_output())
    # Run and fetch the result
    ret = sess.run(embedding, feed_dict={"input_ids_p:0": input_ids, "input_mask_p:0": input_mask, "segment_ids_p:0": segment_ids})
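
    Feeding by tensor name ("input_ids_p:0") is equivalent to passing the placeholder objects themselves; the [None, None] placeholder shape lets the same graph accept any batch size and sequence length:

    ret = sess.run(embedding, feed_dict={input_ids_p: input_ids,
                                         input_mask_p: input_mask,
                                         segment_ids_p: segment_ids})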

    The complete runnable code:

    import tensorflow as tf 
    from bert import modeling
    import collections
    import os
    import numpy as np 
    import json
    
    flags = tf.flags
    FLAGS = flags.FLAGS
    bert_path = '/home/xiangbo_wang/xiangbo/NER/chinese_L-12_H-768_A-12/'
    
    flags.DEFINE_string(
        'bert_config_file', os.path.join(bert_path, 'bert_config.json'),
        'config json file corresponding to the pre-trained BERT model.'
    )
    flags.DEFINE_string(
        'bert_vocab_file', os.path.join(bert_path,'vocab.txt'),
        'the config vocab file',
    )
    flags.DEFINE_string(
        'init_checkpoint', os.path.join(bert_path,'bert_model.ckpt'),
        'from a pre-trained BERT get an initial checkpoint',
    )
    flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
    
    def convert2Uni(text):
        if isinstance(text, str):
            return text
        elif isinstance(text, bytes):
            return text.decode('utf-8','ignore')
        else:
            print(type(text))
            print('####################wrong################')
    
    
    def load_vocab(vocab_file):
        vocab = collections.OrderedDict()
        vocab.setdefault('blank', 2)
        index = 0
        with open(vocab_file, encoding='utf-8') as reader:  # vocab.txt is UTF-8
        # with tf.gfile.GFile(vocab_file, 'r') as reader:
            while True:
                tmp = reader.readline()
                if not tmp:
                    break
                token = convert2Uni(tmp)
                token = token.strip()
                vocab[token] = index 
                index+=1
        return vocab
    
    
    def inputs(vectors, maxlen = 50):
        length = len(vectors)
        if length > maxlen:
            return vectors[0:maxlen], [1]*maxlen, [0]*maxlen
        else:
            input = vectors+[0]*(maxlen-length)
            mask = [1]*length + [0]*(maxlen-length)
            segment = [0]*maxlen
            return input, mask, segment
    
    
    def response_request(text):
        vectors = [dictionary.get('[CLS]')] + [dictionary.get(i) if i in dictionary else dictionary.get('[UNK]') for i in list(text)] + [dictionary.get('[SEP]')]
        input, mask, segment = inputs(vectors)
    
        input_ids = np.reshape(np.array(input), [1, -1])
        input_mask = np.reshape(np.array(mask), [1, -1])
        segment_ids = np.reshape(np.array(segment), [1, -1])
    
        embedding = tf.squeeze(model.get_sequence_output())
        rst = sess.run(embedding, feed_dict={'input_ids_p:0':input_ids, 'input_mask_p:0':input_mask, 'segment_ids_p:0':segment_ids})
    
        return json.dumps(rst.tolist(), ensure_ascii=False)
    
    
    dictionary = load_vocab(FLAGS.bert_vocab_file)
    init_checkpoint = FLAGS.init_checkpoint
    
    sess = tf.Session()
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
    
    input_ids_p = tf.placeholder(shape=[None, None], dtype = tf.int32, name='input_ids_p')
    input_mask_p = tf.placeholder(shape=[None, None], dtype = tf.int32, name='input_mask_p')
    segment_ids_p = tf.placeholder(shape=[None, None], dtype = tf.int32, name='segment_ids_p')
    
    model = modeling.BertModel(
        config = bert_config,
        is_training = FLAGS.use_tpu,    # False here: inference mode, so dropout is disabled
        input_ids = input_ids_p,
        input_mask = input_mask_p,
        token_type_ids = segment_ids_p,
        use_one_hot_embeddings = FLAGS.use_tpu,  # False: gather-based embedding lookup
    )
    print('####################################')
    restore_saver = tf.train.Saver()
    restore_saver.restore(sess, init_checkpoint)
    
    print(response_request('我叫水奈樾。'))
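
    For the test sentence above (6 characters plus [CLS] and [SEP], padded to maxlen=50), the script prints a 50x768 nested list: tf.squeeze has removed the batch dimension of 1, and 768 is this model's hidden size.

    The hand-rolled load_vocab mirrors bert's own tokenization module; for full WordPiece handling, that module could be used instead (a sketch, assuming the standard google-research/bert repo layout):

    from bert import tokenization

    # FullTokenizer handles vocab loading, basic tokenization, and WordPiece
    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.bert_vocab_file, do_lower_case=True)
    tokens = ['[CLS]'] + tokenizer.tokenize('我叫水奈樾。') + ['[SEP]']
    ids = tokenizer.convert_tokens_to_ids(tokens)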

     
