  • PyTorch - Manually Implementing BERT's Training Process (Condensed Version)

    Video walkthrough

    Or jump straight to the code --> Github

    Imports:

    import re
    import math
    import torch
    import numpy as np
    from random import *
    import torch.nn as nn
    import torch.optim as optim
    import torch.utils.data as Data
    

    1. Data Preprocessing

    1.1 Build the vocabulary and index mappings

    text = (
        'Hello, how are you? I am Romeo.\n'                   # R
        'Hello, Romeo My name is Juliet. Nice to meet you.\n' # J
        'Nice to meet you too. How are you today?\n'          # R
        'Great. My baseball team won the competition.\n'      # J
        'Oh Congratulations, Juliet\n'                        # R
        'Thank you Romeo\n'                                   # J
        'Where are you going today?\n'                        # R
        'I am going shopping. What about you?\n'              # J
        'I am going to visit my grandmother. she is not very well' # R
    )
    sentences = re.sub("[.,!?\-]", '', text.lower()).split('\n')    # filter '.', ',', '?', '!'
    
    # list of all unique words across the sentences
    word_list = list(set(" ".join(sentences).split()))               # ['hello', 'how', 'are', 'you',...]

    # assign an index to every word in the vocabulary
    word2idx = {'[PAD]' : 0, '[CLS]' : 1, '[SEP]' : 2, '[MASK]' : 3}
    for i, w in enumerate(word_list):
        word2idx[w] = i + 4

    # reverse mapping: idx back to word
    idx2word = {i: w for i, w in enumerate(word2idx)}
    vocab_size = len(word2idx)         # 40

    # a token is just the word's index in the vocabulary
    token_list = list()                # token_list holds the token sequence of each sentence
    for sentence in sentences:
        arr = [word2idx[s] for s in sentence.split()]
        token_list.append(arr)
    

    A quick look:

    print(sentences[1])   # hello romeo my name is juliet nice to meet you
    print(token_list[1])  # [14, 31, 35, 33, 27, 11, 8, 16, 5, 34]
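
    The reverse mapping can be checked the same way (a small sketch I added, not in the original post):

    print([idx2word[idx] for idx in token_list[1]])
    # ['hello', 'romeo', 'my', 'name', 'is', 'juliet', 'nice', 'to', 'meet', 'you']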
    

    1.2 Set the hyperparameters

    maxlen = 30      # maximum padded sequence length, i.e. seq_len below
    batch_size = 6 

    # BERT model hyperparameters
    max_pred = 5     # max tokens of prediction
    n_layers = 6     # number of Transformer layers in BERT
    n_heads = 12     # number of attention heads
    d_model = 768    # i.e. embedding_dim
    d_ff = 768*4     # 4*d_model, FeedForward dimension
    d_k = d_v = 64   # dimension of K(=Q) and V: d_model split across n_heads, 768 // 12 = 64

    n_segments = 2   # number of segments (sentences) per input
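
    These values are tied together; a quick consistency check (my own addition, not part of the original code):

    assert d_model % n_heads == 0
    assert d_k == d_model // n_heads   # 768 // 12 = 64, the per-head dimension
    assert d_ff == 4 * d_model         # 3072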
    

    2. Implement the DataLoader

    2.1 Generate the data

    • Randomly select 15% of all tokens in the corpus to be masked.

    • After the tokens to mask have been chosen (see the sketch after this list):

      • with 80% probability, the selected token is replaced with [MASK]

      • with 10% probability, it is replaced with a random non-special word

      • with 10% probability, it is left unchanged as the original word
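
    A minimal sketch of that rule, deciding one selected position with a single random draw (my own illustration; the real implementation follows below):

    p = random()              # random, randint come from `from random import *` above
    if p < 0.8:
        replacement = '[MASK]'                               # 80%: real mask
    elif p < 0.9:
        replacement = idx2word[randint(4, vocab_size - 1)]   # 10%: random non-special word
    else:
        replacement = None                                   # 10%: keep the original word

    The full make_data function below applies this rule per selected position, pads everything to maxlen, and balances IsNext / NotNext examples: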

    # sample IsNext and NotNext so they are balanced within the small batch
    def make_data():
        batch = []
        positive = negative = 0
        while (positive != batch_size / 2) or (negative != batch_size / 2):
            # ========================== BERT input representation ========================
            # randomly pick the indices of two sentences
            tokens_a_index, tokens_b_index = randrange(len(sentences)), randrange(len(sentences)) # sample random index in sentences
            # the two sentences themselves
            tokens_a, tokens_b = token_list[tokens_a_index], token_list[tokens_b_index]
            # Token Embeddings (no WordPiece here): each word's index in the vocabulary
            input_ids = [word2idx['[CLS]']] + tokens_a + [word2idx['[SEP]']] + tokens_b + [word2idx['[SEP]']]
            # Segment Embeddings: distinguish the two sentences (first sentence, [CLS]~[SEP], all 0; second all 1)
            segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)
            
            # ========================== Masked LM =========================================
            n_pred = min(max_pred, max(1, int(len(input_ids) * 0.15)))                        # 15% of the tokens in one sequence
            # positions of candidate tokens in input_ids (excluding [CLS] and [SEP])
            cand_maked_pos = [i for i, token in enumerate(input_ids) 
                              if token != word2idx['[CLS]'] and token != word2idx['[SEP]']]  # candidate masked positions
            shuffle(cand_maked_pos)
            
            masked_tokens, masked_pos = [], []     # the masked tokens and their positions
            for pos in cand_maked_pos[:n_pred]:    # randomly mask 15% of the tokens
                masked_pos.append(pos)
                masked_tokens.append(input_ids[pos])
                # decide how to replace the selected token (one random draw per position)
                p = random()
                if p < 0.8:                            # 80%: replace with [MASK]
                    input_ids[pos] = word2idx['[MASK]']
                elif p < 0.9:                          # 10%: replace with a random word
                    index = randint(0, vocab_size - 1)     # random index in vocabulary
                    while index < 4:                       # must not be [PAD], [CLS], [SEP], [MASK]
                        index = randint(0, vocab_size - 1)
                    input_ids[pos] = index
                # remaining 10%: keep the original word, do nothing
                
            # ========================== Paddings ==========================================
            # pad input_ids (and segment_ids) to the same length
            n_pad = maxlen - len(input_ids)
            input_ids.extend([word2idx['[PAD]']] * n_pad)
            segment_ids.extend([0] * n_pad)
                
            # zero-pad masked_tokens / masked_pos up to max_pred
            if max_pred > n_pred:
                n_pad = max_pred - n_pred
                masked_tokens.extend([0] * n_pad)
                masked_pos.extend([0] * n_pad)
            
            # keep equal numbers of positive and negative examples
            if tokens_a_index + 1 == tokens_b_index and positive < batch_size / 2:
                batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True])  # IsNext
                positive += 1
            elif tokens_a_index + 1 != tokens_b_index and negative < batch_size / 2:
                batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False]) # NotNext
                negative += 1
        
        return batch
    

    Call the function above (this produces one batch of data):

    batch = make_data()
    
    # one batch of data
    input_ids, segment_ids, masked_tokens, masked_pos, isNext = zip(*batch)  
    # everything must be converted to LongTensor
    input_ids, segment_ids, masked_tokens, masked_pos, isNext = \
        torch.LongTensor(input_ids), torch.LongTensor(segment_ids), torch.LongTensor(masked_tokens), \
        torch.LongTensor(masked_pos), torch.LongTensor(isNext)
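
    The resulting tensor shapes follow from the hyperparameters above (a quick check I added; maxlen=30, max_pred=5, batch_size=6):

    print(input_ids.shape)      # torch.Size([6, 30])
    print(masked_tokens.shape)  # torch.Size([6, 5])
    print(masked_pos.shape)     # torch.Size([6, 5])
    print(isNext.shape)         # torch.Size([6])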
    

    2.2 Implement the DataLoader

    • To use a DataLoader, we need to define the following two functions:

      • __len__: returns how many items are in the whole dataset

      • __getitem__: returns one item for a given index

    With a DataLoader we can easily shuffle the whole dataset, draw a batch of data, and so on.

    class MyDataSet(Data.Dataset):
        def __init__(self, input_ids, segment_ids, masked_tokens, masked_pos, isNext):
            # convert everything to LongTensor
            self.input_ids = torch.LongTensor(input_ids)
            self.segment_ids = torch.LongTensor(segment_ids)
            self.masked_tokens = torch.LongTensor(masked_tokens) 
            self.masked_pos = torch.LongTensor(masked_pos) 
            self.isNext = torch.LongTensor(isNext)
            
        def __len__(self):
            return len(self.input_ids)
        
        def __getitem__(self, idx):
            return self.input_ids[idx], self.segment_ids[idx], self.masked_tokens[idx], self.masked_pos[idx], self.isNext[idx]
        
    dataset = MyDataSet(input_ids, segment_ids, masked_tokens, masked_pos, isNext)
    dataloader = Data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
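
    Iterating over the DataLoader follows the usual PyTorch pattern (a sketch I added of how a later training loop would consume these batches):

    for input_ids, segment_ids, masked_tokens, masked_pos, isNext in dataloader:
        # each tensor has batch_size rows: the model inputs plus the MLM and NSP labels
        pass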
    

    Inspect the data:

    print(next(iter(dataloader)))
    print(len(dataloader))           # just one batch
    

    Output:

    [tensor([[ 1,  3, 13, 11,  2,  7, 34, 31,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,
              0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
            [ 1,  4, 23,  3, 16, 17, 35, 30, 18, 27, 29, 36, 24,  2,  3, 13, 11,  2,
              0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
            [ 1,  6, 13, 11,  2,  3, 34, 31,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,
              0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
            [ 1,  3, 23, 37, 39, 26,  3, 34,  2,  4, 23, 37, 16,  3, 35, 30, 18, 27,
             29, 36, 24,  2,  0,  0,  0,  0,  0,  0,  0,  0],
            [ 1,  7, 34, 31,  2,  4, 23, 37, 39, 26, 21, 34,  2,  0,  0,  0,  0,  0,
              0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
            [ 1,  7, 34, 31,  2,  8, 16,  3, 34, 32, 19, 12, 34, 28,  2,  0,  0,  0,
              0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]]), tensor([[0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
             0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0]]), tensor([[ 6,  0,  0,  0,  0],
            [ 6, 37,  0,  0,  0],
            [ 7,  0,  0,  0,  0],
            [17, 21,  4,  0,  0],
            [ 7,  0,  0,  0,  0],
            [ 5, 34,  0,  0,  0]]), tensor([[ 1,  0,  0,  0,  0],
            [14,  3,  0,  0,  0],
            [ 5,  0,  0,  0,  0],
            [13,  6,  1,  0,  0],
            [ 1,  0,  0,  0,  0],
            [ 7,  2,  0,  0,  0]]), tensor([1, 0, 1, 1, 0, 0])]
    1
    
  • Original post: https://www.cnblogs.com/douzujun/p/13557433.html