  • Building training data for BERT+LSTM+CRF entity recognition

    I. BERT+LSTM+CRF has become a common approach to named entity recognition. BERT can serve either as a fixed embedding layer or be fine-tuned together with the rest of the model. The input fed to BERT has to follow a specific format: each sentence is wrapped with "[CLS]" and "[SEP]", and an attention mask is required. The script below uses pad_sequences to truncate and pad the sentences so that every input has the same length. After the training set is constructed, a Chinese pretrained model is downloaded, the model and its vocabulary (vocab) are loaded with the corresponding configuration, and finally ALBERT extracts the sentence embeddings. These embeddings can then be combined with other models in a downstream task to train for a specific objective.
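    To make that input format concrete before the full script, here is a minimal standalone sketch for a single sentence. It uses the Hugging Face transformers tokenizer and the public bert-base-chinese checkpoint, both assumptions for illustration only; the script below uses the repo's own BertTokenizer and a local vocab file instead.

# Sketch only: standard Hugging Face tokenizer, not the repo's local BertTokenizer.
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')  # assumed public checkpoint
enc = tokenizer.encode_plus('我爱祖国',
                            max_length=10,
                            padding='max_length',   # pad up to the fixed length
                            truncation=True,
                            return_attention_mask=True)
# [CLS]=101 and [SEP]=102 are added automatically, then zero-padded:
print(enc['input_ids'])       # e.g. [101, 2769, 4263, 4862, 1744, 102, 0, 0, 0, 0]
print(enc['attention_mask'])  # e.g. [1, 1, 1, 1, 1, 1, 0, 0, 0, 0]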

import torch
from configs.base import config
from model.modeling_albert import BertConfig, BertModel
from model.tokenization_bert import BertTokenizer
from keras.preprocessing.sequence import pad_sequences
from torch.utils.data import TensorDataset, DataLoader, RandomSampler

import os

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MAX_LEN = 10
if __name__ == '__main__':
    bert_config = BertConfig.from_pretrained(str(config['albert_config_path']), share_type='all')
    base_path = os.getcwd()
    VOCAB = base_path + '/configs/vocab.txt'  # your path for the model and vocab
    tokenizer = BertTokenizer.from_pretrained(VOCAB)

    # Encode the text: one tag per character in each sentence
    tag2idx = {'[SOS]': 101, '[EOS]': 102, '[PAD]': 0, 'B_LOC': 1, 'I_LOC': 2, 'O': 3}
    sentences = ['我是中华人民共和国国民', '我爱祖国']
    tags = ['O O B_LOC I_LOC I_LOC I_LOC I_LOC I_LOC I_LOC O O', 'O O O O']

    tokenized_text = [tokenizer.tokenize(sent) for sent in sentences]
    # Use pad_sequences to truncate and pad every sequence to the same length.
    # Note: pad_sequences only takes a 2-D batch (more than one sample at a time),
    # so a very large corpus processed this way would all have to fit in memory.
    input_ids = pad_sequences([tokenizer.convert_tokens_to_ids(txt) for txt in tokenized_text],
                              maxlen=MAX_LEN - 2,   # leave room for [CLS] and [SEP]
                              truncating='post',
                              padding='post',
                              value=0)

    tag_ids = pad_sequences([[tag2idx.get(tok) for tok in tag.split()] for tag in tags],
                            maxlen=MAX_LEN - 2,
                            padding='post',
                            truncating='post',
                            value=0)

    # BERT expects [CLS] (id 101) at the start and [SEP] (id 102) at the end of each sentence.
    input_ids_cls_sep = []
    for input_id in input_ids:
        linelist = []
        linelist.append(101)          # prepend [CLS]
        flag = True
        for tok in input_id:
            if tok > 0:
                linelist.append(tok)
            elif tok == 0 and flag:   # first padding position: insert [SEP] before the pads
                linelist.append(102)
                linelist.append(tok)
                flag = False
            else:
                linelist.append(tok)
        if tok > 0:                   # sequence had no padding: append [SEP] at the very end
            linelist.append(102)
        input_ids_cls_sep.append(linelist)

    # Apply the same [CLS]/[SEP] insertion to the tag sequences (mapped to the [SOS]/[EOS] ids).
    tag_ids_cls_sep = []
    for tag_id in tag_ids:
        linelist = []
        linelist.append(101)
        flag = True
        for tag in tag_id:
            if tag > 0:
                linelist.append(tag)
            elif tag == 0 and flag:
                linelist.append(102)
                linelist.append(tag)
                flag = False
            else:
                linelist.append(tag)
        if tag > 0:
            linelist.append(102)
        tag_ids_cls_sep.append(linelist)

    # The attention mask is 1 for real tokens and 0 for padding.
    attention_masks = [[int(tok > 0) for tok in line] for line in input_ids_cls_sep]

    print('---------------------------')
    print('input_ids:{}'.format(input_ids_cls_sep))
    print('tag_ids:{}'.format(tag_ids_cls_sep))
    print('attention_masks:{}'.format(attention_masks))


    # input_ids = torch.tensor([tokenizer.encode('我 是 中 华 人 民 共 和 国 国 民', add_special_tokens=True)])  # add_special_tokens=True adds [CLS] and [SEP] automatically
    # print('input_ids:{}, size:{}'.format(input_ids, len(input_ids)))
    # print('attention_masks:{}, size:{}'.format(attention_masks, len(attention_masks)))

    inputs_tensor = torch.tensor(input_ids_cls_sep)
    tags_tensor = torch.tensor(tag_ids_cls_sep)
    masks_tensor = torch.tensor(attention_masks)

    train_data = TensorDataset(inputs_tensor, tags_tensor, masks_tensor)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=2)

    model = BertModel.from_pretrained(config['bert_dir'], config=bert_config)
    model.to(device)
    model.eval()
    with torch.no_grad():
        '''
        Note:
        1. If the config sets "output_hidden_states": true and "output_attentions": true,
           the model returns sequence_output, pooled_output, (hidden_states), (attentions)
           for all layers, so: all_hidden_states, all_attentions = model(input_ids)[-2:]
        2. If output_hidden_states and output_attentions are not set,
           only the last layer's sequence_output and pooled_output are returned.
        '''
        for index, batch in enumerate(train_dataloader):
            batch = tuple(t.to(device) for t in batch)
            # Unpack in the same order the TensorDataset was built: ids, tags, masks
            b_input_ids, b_labels, b_input_mask = batch
            outputs = model(input_ids=b_input_ids, attention_mask=b_input_mask)
            print(len(outputs))
            all_hidden_states, all_attentions = outputs[-2:]  # hidden states and attentions for all layers
            print(all_hidden_states[-2].shape)  # shape of the second-to-last layer's hidden states
            print(all_hidden_states[-2])
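
    As the note in the code says, len(outputs) is 4 only when the hidden-state and attention outputs are switched on in the ALBERT config. A minimal sketch of the two usual ways to do that, assuming the bundled BertConfig follows the Hugging Face convention of accepting attribute overrides as keyword arguments (otherwise edit the config JSON directly):

# Option 1 (assumed kwargs behaviour): override the attributes when loading the config
bert_config = BertConfig.from_pretrained(str(config['albert_config_path']),
                                         share_type='all',
                                         output_hidden_states=True,
                                         output_attentions=True)

# Option 2: set the flags directly in the config JSON file instead:
#   "output_hidden_states": true,
#   "output_attentions": true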

    II. Printed output

    input_ids:[[101, 2769, 3221, 704, 1290, 782, 3696, 1066, 1469, 102], [101, 2769, 4263, 4862, 1744, 102, 0, 0, 0, 0]]
    tag_ids:[[101, 3, 3, 1, 2, 2, 2, 2, 2, 102], [101, 3, 3, 3, 3, 102, 0, 0, 0, 0]]
    attention_masks:[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 0, 0, 0, 0]]
    4
    torch.Size([2, 10, 768])
    tensor([[[-1.1074, -0.0047,  0.4608,  ..., -0.1816, -0.6379,  0.2295],
             [-0.1930, -0.4629,  0.4127,  ..., -0.5227, -0.2401, -0.1014],
             [ 0.2682, -0.6617,  0.2744,  ..., -0.6689, -0.4464,  0.1460],
             ...,
             [-0.1723, -0.7065,  0.4111,  ..., -0.6570, -0.3490, -0.5541],
             [-0.2028, -0.7025,  0.3954,  ..., -0.6566, -0.3653, -0.5655],
             [-0.2026, -0.6831,  0.3778,  ..., -0.6461, -0.3654, -0.5523]],

            [[-1.3166, -0.0052,  0.6554,  ..., -0.2217, -0.5685,  0.4270],
             [-0.2755, -0.3229,  0.4831,  ..., -0.5839, -0.1757, -0.1054],
             [-1.4941, -0.1436,  0.8720,  ..., -0.8316, -0.5213, -0.3893],
             ...,
             [-0.7022, -0.4104,  0.5598,  ..., -0.6664, -0.1627, -0.6270],
             [-0.7389, -0.2896,  0.6083,  ..., -0.7895, -0.2251, -0.4088],
             [-0.0351, -0.9981,  0.0660,  ..., -0.4606,  0.4439, -0.6745]]])
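
    The post stops at extracting the embeddings; the downstream BiLSTM+CRF layers mentioned at the start would consume them next. Below is a minimal sketch of such a head, assuming the pytorch-crf package (pip install pytorch-crf, not used in the original post) and assuming the NER tags are first re-mapped into a small contiguous range (the tag2idx above assigns 101/102 to [SOS]/[EOS], which a CRF layer cannot use directly).

import torch
import torch.nn as nn
from torchcrf import CRF  # pytorch-crf package -- an assumption, not part of the original script

class BiLstmCrfHead(nn.Module):
    """Downstream BiLSTM+CRF head consuming BERT embeddings of shape (batch, seq_len, 768)."""
    def __init__(self, embed_dim=768, lstm_dim=128, num_tags=6):
        super().__init__()
        self.lstm = nn.LSTM(embed_dim, lstm_dim, batch_first=True, bidirectional=True)
        self.emission = nn.Linear(2 * lstm_dim, num_tags)
        self.crf = CRF(num_tags, batch_first=True)

    def forward(self, bert_embeddings, tags=None, mask=None):
        lstm_out, _ = self.lstm(bert_embeddings)          # (batch, seq_len, 2 * lstm_dim)
        emissions = self.emission(lstm_out)               # (batch, seq_len, num_tags)
        if tags is not None:
            return -self.crf(emissions, tags, mask=mask)  # negative log-likelihood for training
        return self.crf.decode(emissions, mask=mask)      # best tag paths for inference

# Example usage with the tensors from the loop above (shapes only; tag ids must be remapped first):
# head = BiLstmCrfHead().to(device)
# loss = head(all_hidden_states[-1], tags=remapped_tag_ids, mask=b_input_mask.bool())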

  • Original article: https://www.cnblogs.com/little-horse/p/11731552.html