  • PyTorch: Manually Implementing BERT's Training Process (Simplified Version)

    Imports:

        import re
        import math
        import torch
        import numpy as np
        from random import *
        import torch.nn as nn
        import torch.optim as optim
        import torch.utils.data as Data

    1. Data preprocessing

    1.1 Building the vocabulary and index mappings

        text = (
            'Hello, how are you? I am Romeo.\n'                          # R
            'Hello, Romeo My name is Juliet. Nice to meet you.\n'        # J
            'Nice to meet you too. How are you today?\n'                 # R
            'Great. My baseball team won the competition.\n'             # J
            'Oh Congratulations, Juliet\n'                               # R
            'Thank you Romeo\n'                                          # J
            'Where are you going today?\n'                               # R
            'I am going shopping. What about you?\n'                     # J
            'I am going to visit my grandmother. she is not very well'   # R
        )
        sentences = re.sub(r"[.,!?\-]", '', text.lower()).split('\n')    # strip '.', ',', '?', '!', '-'
        word_list = list(set(" ".join(sentences).split()))               # ['hello', 'how', 'are', 'you', ...]
        word2idx = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3}
        for i, w in enumerate(word_list):
            word2idx[w] = i + 4
        idx2word = {i: w for i, w in enumerate(word2idx)}
        vocab_size = len(word2idx)          # 40

        token_list = list()                 # token_list holds the token ids of each sentence
        for sentence in sentences:
            arr = [word2idx[s] for s in sentence.split()]
            token_list.append(arr)

    A quick check:

        print(sentences[0])                  # hello how are you i am romeo
        print(token_list[0])                 # [28, 22, 27, 35, 11, 4, 15]

    1.2 Hyperparameters

        maxlen = 30      # maximum (padded) sentence length, i.e. the seq_len used below
        batch_size = 6
        max_pred = 5     # maximum number of masked tokens to predict per sample
        n_layers = 6     # number of Transformer encoder layers in BERT
        n_heads = 12
        d_model = 768    # i.e. embedding_dim
        d_ff = 768 * 4   # 4 * d_model, FeedForward dimension
        d_k = d_v = 64   # dimension of K (= Q) and V per head, i.e. d_model split across n_heads (64 * 12 == 768)
        n_segments = 2   # number of segments (sentences) to distinguish

    2. Implementing the DataLoader

    2.1 Generating the data

    Randomly mask 15% of the tokens in each sample (of the selected positions, 80% are replaced with [MASK], 10% are replaced with a random non-special word, and the remaining 10% keep the original token). For example, a 22-token input gives n_pred = min(5, max(1, int(22 * 0.15))) = 3 masked positions.

        def make_data():
            batch = []
            positive = negative = 0
            while positive != batch_size/2 or negative != batch_size/2:
                tokens_a_index, tokens_b_index = randrange(len(sentences)), randrange(len(sentences))  # sample two random sentence indices
                tokens_a, tokens_b = token_list[tokens_a_index], token_list[tokens_b_index]
                input_ids = [word2idx['[CLS]']] + tokens_a + [word2idx['[SEP]']] + tokens_b + [word2idx['[SEP]']]  # vocabulary ids of the tokens
                segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)                            # 0 for the first sentence, 1 for the second

                # MASK LM
                n_pred = min(max_pred, max(1, int(len(input_ids) * 0.15)))                      # mask 15% of the tokens in one sample
                cand_maked_pos = [i for i, token in enumerate(input_ids)
                                  if token != word2idx['[CLS]'] and token != word2idx['[SEP]']] # candidate masked positions
                shuffle(cand_maked_pos)
                masked_tokens, masked_pos = [], []   # original tokens at the masked positions, and those positions
                for pos in cand_maked_pos[:n_pred]:
                    masked_pos.append(pos)
                    masked_tokens.append(input_ids[pos])
                    p = random()
                    if p < 0.8:                      # 80%: replace with [MASK]
                        input_ids[pos] = word2idx['[MASK]']
                    elif p < 0.9:                    # 10%: replace with a random word; the remaining 10% keep the original token
                        index = randint(0, vocab_size - 1)   # random index in the vocabulary
                        while index < 4:             # skip the special tokens [PAD], [CLS], [SEP], [MASK]
                            index = randint(0, vocab_size - 1)
                        input_ids[pos] = index       # replace

                # Zero Paddings
                n_pad = maxlen - len(input_ids)
                input_ids.extend([0] * n_pad)
                segment_ids.extend([0] * n_pad)

                # Pad masked_tokens / masked_pos up to max_pred
                if max_pred > n_pred:
                    n_pad = max_pred - n_pred
                    masked_tokens.extend([0] * n_pad)
                    masked_pos.extend([0] * n_pad)

                if tokens_a_index + 1 == tokens_b_index and positive < batch_size/2:
                    batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True])   # IsNext
                    positive += 1
                elif tokens_a_index + 1 != tokens_b_index and negative < batch_size/2:
                    batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False])  # NotNext
                    negative += 1
            return batch

    Call the function above:

        batch = make_data()

        input_ids, segment_ids, masked_tokens, masked_pos, isNext = zip(*batch)   # everything must be converted to LongTensor
        input_ids, segment_ids, masked_tokens, masked_pos, isNext = \
            torch.LongTensor(input_ids), torch.LongTensor(segment_ids), torch.LongTensor(masked_tokens), \
            torch.LongTensor(masked_pos), torch.LongTensor(isNext)

    2.2 Creating the DataLoader

        class MyDataSet(Data.Dataset):
            def __init__(self, input_ids, segment_ids, masked_tokens, masked_pos, isNext):
                self.input_ids = input_ids
                self.segment_ids = segment_ids
                self.masked_tokens = masked_tokens
                self.masked_pos = masked_pos
                self.isNext = isNext

            def __len__(self):
                return len(self.input_ids)

            def __getitem__(self, idx):
                return self.input_ids[idx], self.segment_ids[idx], self.masked_tokens[idx], self.masked_pos[idx], self.isNext[idx]

        loader = Data.DataLoader(MyDataSet(input_ids, segment_ids, masked_tokens, masked_pos, isNext), batch_size, True)

    Inspect one batch from the loader:

        print(next(iter(loader)))

        [tensor([[ 1, 36, 21, 38,  2, 30,  8, 10,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,
                   0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
                [ 1, 27, 10, 22,  9, 23, 38, 39, 20, 19,  8,  2, 33,  3, 11, 26, 35, 32,
                 12,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
                [ 1, 39, 20, 19,  8,  6, 34,  5,  8,  3,  2, 27,  3, 22,  9,  3, 38, 39,
                 20, 19,  8,  2,  0,  0,  0,  0,  0,  0,  0,  0],
                [ 1, 18,  3, 13, 14, 29,  4,  8,  2, 18,  3, 13, 20, 24, 22,  7, 37, 23,
                 28, 16, 31,  2,  0,  0,  0,  0,  0,  0,  0,  0],
                [ 1,  3, 17, 13, 14, 29,  4,  8,  2, 18, 17, 13, 20, 24, 22,  3, 37, 23,
                 28, 16, 31,  2,  0,  0,  0,  0,  0,  0,  0,  0],
                [ 1, 33, 22, 11, 26, 35, 32, 12,  2,  3,  8, 10,  2,  0,  0,  0,  0,  0,
                   0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]]),
        tensor([[0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,
                 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
                 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
                 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
                 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                 0, 0, 0, 0, 0, 0]]),
        tensor([[21,  0,  0,  0,  0],
                [22, 27,  8,  0,  0],
                [23, 25, 10,  0,  0],
                [18, 17, 17,  0,  0],
                [17, 18,  7,  0,  0],
                [30,  0,  0,  0,  0]]),
        tensor([[ 2,  0,  0,  0,  0],
                [13,  1, 10,  0,  0],
                [15,  9, 12,  0,  0],
                [ 9, 10,  2,  0,  0],
                [10,  1, 15,  0,  0],
                [ 9,  0,  0,  0,  0]]),
        tensor([1, 0, 0, 1, 1, 0])]

    3. The BERT model

    3.1 Embedding

        class BertEmbedding(nn.Module):
            def __init__(self):
                super(BertEmbedding, self).__init__()
                self.tok_embed = nn.Embedding(vocab_size, d_model)  # token embedding
                self.pos_embed = nn.Embedding(maxlen, d_model)      # position embedding (simplified here; the reference code uses sin/cos positional encodings)
                self.seg_embed = nn.Embedding(n_segments, d_model)  # segment (token type) embedding
                self.norm = nn.LayerNorm(d_model)

            def forward(self, x, seg):                              # x and seg both have shape [batch_size, seq_len]
                seq_len = x.size(1)
                pos = torch.arange(seq_len, dtype=torch.long)
                pos = pos.unsqueeze(0).expand_as(x)                 # [seq_len] -> [batch_size, seq_len]
                embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)  # sum of the three embeddings
                return self.norm(embedding)
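
    A quick shape check (an optional sketch, not part of the original post), reusing the tensors built in section 2.1:

        # Optional sanity check: the summed embedding keeps shape [batch_size, seq_len, d_model]
        emb = BertEmbedding()
        print(emb(input_ids[:2], segment_ids[:2]).shape)   # torch.Size([2, 30, 768])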

    3.2 Generating the attention padding mask

        def get_attn_pad_mask(seq_q, seq_k):        # seq_q and seq_k both have shape [batch_size, seq_len]
            batch_size, seq_len = seq_q.size()
            # eq(zero) marks the [PAD] tokens
            pad_attn_mask = seq_q.data.eq(0).unsqueeze(1)              # [batch_size, 1, seq_len]
            return pad_attn_mask.expand(batch_size, seq_len, seq_len)  # [batch_size, seq_len, seq_len]
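
    As a toy illustration (hypothetical ids, not from the original post), every column corresponding to a [PAD] token (id 0) is flagged True, so it can later be filled with a large negative value before the softmax:

        # Toy example: a single 4-token sequence whose last token is [PAD]
        toy = torch.LongTensor([[1, 5, 7, 0]])
        print(get_attn_pad_mask(toy, toy))
        # tensor([[[False, False, False,  True],
        #          [False, False, False,  True],
        #          [False, False, False,  True],
        #          [False, False, False,  True]]])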

    3.3 Activation function (GELU)

        def gelu(x):
            return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
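
    As a quick sanity check (not in the original post), this erf-based formula is the exact GELU and should agree with PyTorch's built-in implementation:

        # Optional check: gelu() above matches nn.functional.gelu (exact, erf-based by default)
        x = torch.randn(5)
        print(torch.allclose(gelu(x), nn.functional.gelu(x)))   # True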

    3.4 Scaled dot-product attention

        class ScaledDotProductAttention(nn.Module):     # dot-product attention
            def __init__(self):
                super(ScaledDotProductAttention, self).__init__()

            def forward(self, Q, K, V, attn_mask):      # as called below, Q, K, V all have shape [batch_size, n_heads, seq_len, d_k]
                scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)  # [batch_size, n_heads, seq_len, seq_len]
                scores.masked_fill_(attn_mask, -1e9)    # positions where attn_mask is True (padding) get -1e9, so their softmax weight is ~0
                attn = nn.Softmax(dim=-1)(scores)       # normalize over the last dimension, [batch_size, n_heads, seq_len, seq_len]
                context = torch.matmul(attn, V)         # [batch_size, n_heads, seq_len, d_k]
                return context

    3.5 Multi-head attention

        class MultiHeadAttention(nn.Module):
            def __init__(self):
                super(MultiHeadAttention, self).__init__()
                self.W_Q = nn.Linear(d_model, d_k * n_heads)   # effectively [d_model, d_model], since d_k * n_heads == d_model
                self.W_K = nn.Linear(d_model, d_k * n_heads)
                self.W_V = nn.Linear(d_model, d_v * n_heads)
                self.fc = nn.Linear(n_heads * d_v, d_model)    # output projection; defined here (not inside forward) so its parameters are registered and trained
                self.norm = nn.LayerNorm(d_model)

            def forward(self, Q, K, V, attn_mask):             # Q, K, V: [batch_size, seq_len, d_model], attn_mask: [batch_size, seq_len, seq_len]
                residual, batch_size = Q, Q.size(0)
                # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
                q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # q_s: [batch_size, n_heads, seq_len, d_k]
                k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # k_s: [batch_size, n_heads, seq_len, d_k]
                v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)  # v_s: [batch_size, n_heads, seq_len, d_v]

                attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)           # attn_mask: [batch_size, n_heads, seq_len, seq_len]

                # context: [batch_size, n_heads, seq_len, d_v]
                context = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask)
                context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v)  # [batch_size, seq_len, n_heads * d_v]
                output = self.fc(context)
                return self.norm(output + residual)            # output: [batch_size, seq_len, d_model]

    3.6 Feed-forward network

        class PoswiseFeedForwardNet(nn.Module):       # position-wise feed-forward: linear -> GELU -> linear
            def __init__(self):
                super(PoswiseFeedForwardNet, self).__init__()
                self.fc1 = nn.Linear(d_model, d_ff)
                self.fc2 = nn.Linear(d_ff, d_model)

            def forward(self, x):
                # [batch_size, seq_len, d_model] -> [batch_size, seq_len, d_ff] -> [batch_size, seq_len, d_model]
                return self.fc2(gelu(self.fc1(x)))

    3.7 EncoderLayer

    In the reference code, Bidirectional Encoder = Transformer (self-attention).

    Transformer = MultiHead_Attention + Feed_Forward with sublayer connection; the simplified EncoderLayer below omits the sublayer connection (sketched next).
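
    For reference, a minimal sketch (names here are hypothetical, not from the original post) of the residual + dropout + LayerNorm wrapper that the simplified EncoderLayer leaves out:

        class SublayerConnection(nn.Module):
            # Illustrative only: wrap a sub-layer with a residual connection, dropout and LayerNorm
            def __init__(self, size, dropout=0.1):
                super(SublayerConnection, self).__init__()
                self.norm = nn.LayerNorm(size)
                self.dropout = nn.Dropout(dropout)

            def forward(self, x, sublayer):
                # e.g. x = sublayer_connection(x, lambda t: self_attn(t, t, t, mask))
                return self.norm(x + self.dropout(sublayer(x)))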

        class EncoderLayer(nn.Module):    # combines multi-head attention and the feed-forward network
            def __init__(self):
                super(EncoderLayer, self).__init__()
                self.enc_self_attn = MultiHeadAttention()
                self.pos_ffn = PoswiseFeedForwardNet()

            def forward(self, enc_inputs, enc_self_attn_mask):
                enc_outputs = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)  # enc_inputs serves as Q, K and V
                enc_outputs = self.pos_ffn(enc_outputs)             # enc_outputs: [batch_size, seq_len, d_model]
                return enc_outputs

    3.8 The BERT model

        class BERT(nn.Module):
            def __init__(self):
                super(BERT, self).__init__()
                self.embedding = BertEmbedding()
                self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
                self.fc = nn.Sequential(
                    nn.Linear(d_model, d_model),
                    nn.Dropout(0.5),
                    nn.Tanh(),
                )
                self.classifier = nn.Linear(d_model, 2)
                self.linear = nn.Linear(d_model, d_model)
                self.activ2 = gelu
                # fc2 shares its weight with the token embedding layer
                embed_weight = self.embedding.tok_embed.weight
                self.fc2 = nn.Linear(d_model, vocab_size, bias=False)
                self.fc2.weight = embed_weight

            def forward(self, input_ids, segment_ids, masked_pos):  # input_ids, segment_ids: [batch_size, seq_len]; masked_pos: [batch_size, max_pred]
                output = self.embedding(input_ids, segment_ids)     # [batch_size, seq_len, d_model]

                enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids)  # [batch_size, seq_len, seq_len]
                for layer in self.layers:                           # iterating over the layers stacks the Transformer blocks
                    output = layer(output, enc_self_attn_mask)      # output: [batch_size, seq_len, d_model]

                # the NSP prediction is made from the first token ([CLS])
                h_pooled = self.fc(output[:, 0])                    # [batch_size, d_model]
                logits_clsf = self.classifier(h_pooled)             # [batch_size, 2], predicts isNext

                masked_pos = masked_pos[:, :, None].expand(-1, -1, d_model)   # [batch_size, max_pred, d_model]
                h_masked = torch.gather(output, 1, masked_pos)      # gather the hidden states at the masked positions, [batch_size, max_pred, d_model]
                h_masked = self.activ2(self.linear(h_masked))       # [batch_size, max_pred, d_model]
                logits_lm = self.fc2(h_masked)                      # [batch_size, max_pred, vocab_size]
                return logits_lm, logits_clsf                       # logits_lm: [batch_size, max_pred, vocab_size], logits_clsf: [batch_size, 2]
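
    A quick smoke test (an optional sketch, not from the original post) to confirm the output shapes:

        # Optional shape check with dummy inputs (all-zero ids are treated as [PAD], which is fine for a shape test)
        dummy_ids = torch.zeros(2, maxlen, dtype=torch.long)
        dummy_seg = torch.zeros(2, maxlen, dtype=torch.long)
        dummy_pos = torch.zeros(2, max_pred, dtype=torch.long)
        lm_out, clsf_out = BERT()(dummy_ids, dummy_seg, dummy_pos)
        print(lm_out.shape, clsf_out.shape)   # [2, max_pred, vocab_size] and [2, 2]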

    3.9 Instantiating the model

        model = BERT()
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adadelta(model.parameters(), lr=0.001)

    4. Training

        for epoch in range(50):
            for input_ids, segment_ids, masked_tokens, masked_pos, isNext in loader:
                logits_lm, logits_clsf = model(input_ids, segment_ids, masked_pos)  # logits_lm: [batch_size, max_pred, vocab_size], logits_clsf: [batch_size, 2]

                loss_lm = criterion(logits_lm.view(-1, vocab_size), masked_tokens.view(-1))  # masked LM loss
                loss_lm = (loss_lm.float()).mean()
                loss_clsf = criterion(logits_clsf, isNext)  # next-sentence classification loss
                loss = loss_lm + loss_clsf
                if (epoch + 1) % 10 == 0:
                    print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    Epoch: 0010 loss = 2.000775
    Epoch: 0020 loss = 1.244515
    Epoch: 0030 loss = 0.993536
    Epoch: 0040 loss = 0.994489
    Epoch: 0050 loss = 0.929772

    5. Prediction

    First, look at one sample:

        input_ids, segment_ids, masked_tokens, masked_pos, isNext = batch[1]
        print(text)
        print('================================')
        print([idx2word[w] for w in input_ids if idx2word[w] != '[PAD]'])

    Hello, how are you? I am Romeo.
    Hello, Romeo My name is Juliet. Nice to meet you.
    Nice to meet you too. How are you today?
    Great. My baseball team won the competition.
    Oh Congratulations, Juliet
    Thank you Romeo
    Where are you going today?
    I am going shopping. What about you?
    I am going to visit my grandmother. she is not very well
    ================================
    ['[CLS]', 'nice', 'to', 'meet', 'you', 'too', '[MASK]', 'are', 'you', 'today', '[SEP]', 'i', '[MASK]', 'going', 'to', 'visit', 'my', 'grandmother', 'she', 'is', '[MASK]', 'very', 'well', '[SEP]']

        logits_lm, logits_clsf = model(torch.LongTensor([input_ids]), torch.LongTensor([segment_ids]), torch.LongTensor([masked_pos]))
        logits_lm = logits_lm.data.max(2)[1][0].data.numpy()
        print('masked tokens list : ', [pos for pos in masked_tokens if pos != 0])
        print('predict masked tokens list : ', [pos for pos in logits_lm if pos != 0])

    masked tokens list : [32, 22, 4]
    predict masked tokens list : [32, 34, 4]
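
    To read the ids as words (an optional addition, not in the original post), the idx2word mapping from section 1.1 can be applied:

        # Optional: map the token ids back to words via idx2word
        print([idx2word[i] for i in masked_tokens if i != 0])
        print([idx2word[int(i)] for i in logits_lm if i != 0])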

        logits_clsf = logits_clsf.data.max(1)[1].data.numpy()[0]
        print('isNext : ', True if isNext else False)
        print('predict isNext : ', True if logits_clsf else False)

    isNext : False
    predict isNext : False

  • Original post: https://www.cnblogs.com/cxq1126/p/13713264.html