zoukankan      html  css  js  c++  java
  • bert的训练数据的简单构建

    一.简介

    大家都知道原始bert预训练模型有两大任务:
    
                                    1.masked lm:带mask的语言模型
    
                                    2.next sentence prediction:是否为下一句话
    
    bert模型的训练数据有三部分,如下图:
    
                                1.字的token embeddings
    
                                2.句子的embeddings
    
                                3.句子位置的embeddings

    下面就简单的构建一个bert的训练数据



    二.程序
    import re
    import math
    import numpy as np
    import random
    
    text = (
        '随后,文章为中美关系未来发展提出了5点建议。
    '
        '第一,美国应恢复“和平队”等在华奖学金项目。
    '
        '文章称,这些项目在过去几十年帮助美国了解中国,却被特朗普政府因意图孤立中国而取消。
    '
        '第二,美国应停止污名化孔子学院。文章说,孔子学院只是文化中心和教育机构,性质类似于德国的歌德学院和英国文化协会。
    '
        '第三,美国应该允许此前被特朗普政府驱逐出境的中国记者回到美国。同时文章建议中国也允许美国记者入境。
    '
        '第四,美国应取消限制中共党员入境的做法。
    '
        '第五,美方应邀请中国重新开放中国驻休斯顿领事馆。
    '
        '文章称,如此一来,中国也将重新允许美国驻成都领事馆开放。
    '
        '文章最后表示,尽管这些都是微小的举动,但对建立中美互信很有意义,能够为解决更加棘手的问题铺设道路。' 
    )
    
    sentences = re.sub("[.。,“”,!?\-]", '', text.lower()).split('
    ') # 过滤特殊符号
    word_list = list("".join(sentences))
    # 以下是词典的构建
    word2idx = {'[pad]':0, '[cls]':1, '[sep]':2, '[unk]':3, '[mask]':4}
    
    for i, w in enumerate(word_list):
        word2idx[w] = i + 5
    idx2word = {i : w for i, w in enumerate(word2idx)}
    vocab_size = len(word2idx)
    
    token_list = list()
    for sentence in sentences:
        arr = [word2idx[s] for s in list(sentence)]
        token_list.append(arr)
    

      

    # 数据预处理,
    
    maxlen = 120 # 最大句子长度
    max_pred = 5 # 预测每个序列的单词个数
    batch_size = 6
    n_segments = 2 # 输入的几句话
    
    
    '''
    1.Next Sentence Prediction
    
    50%的情况下,句子B是句子A的下一句,而50%的情况下,B不是A的下一句
    
    2.Masked LM and the Masking Procedure
    随机mask一句话中的15% token,拼接2句话
    
        80% 的时间:用[MASK]替换目标单词
        10% 的时间:用随机的单词替换目标单词
        10% 的时间:不改变目标单词
    '''
    def build_data():
        batch = []
        positive = negative = 0
        while (positive != (batch_size/2)) or (negative != (batch_size/2)):
            # 随机选择句子的index,作为A,B句
            tokens_a_index, tokens_b_index = random.randrange(len(sentences)), random.randrange(len(sentences))
            
            tokens_a, tokens_b = token_list[tokens_a_index], token_list[tokens_b_index]
            
            # 拼接A句与B句,格式为:[cls] + A句 + [sep] + B句 + [sep]
            input_ids = [word2idx['[cls]']] + tokens_a + [word2idx['[sep]']] + tokens_b + [word2idx['[sep]']]
            
            # 这里是为了表示两个不同的句子,如A句用0表示,B句用1表示
            segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)
        
            # mask lm,15%随机选择token
            n_pred = min(max_pred, max(1, int(len(input_ids) * 0.15))) # 句子中的15%的token
            
            # 15%随机选择的token,去除特殊符号[cls]与[sep]
            cand_maked_pos = [i for i, token in enumerate(input_ids) if token != word2idx['[cls]'] and token != word2idx['[sep]']] # 候选masked 位置
            
            random.shuffle(cand_maked_pos)
            
            # 存储被mask的词的位置与token
            masked_tokens, masked_pos = [], []
            
            # 对input_ids进行mask, 80%的时间用于mask替换,10%的时间随机替换,10%的时间不替换。
            for pos in cand_maked_pos[:n_pred]:
                masked_pos.append(pos) # mask的位置
                masked_tokens.append(input_ids[pos]) # mask的token
                if random.random() < 0.8: # 80%的时间用mask替换
                    input_ids[pos] = word2idx['[mask]']
                elif random.random() > 0.9: # 10%的时间随机替换
                    index = random.randint(0, vocab_size - 1)
                    while index < 5:
                        index = random.randint(0, vocab_size - 1) # 不包含几个特征符号
                    input_ids[pos] = index
                
            # 进行padding,input_ids与segment_ids补齐到最大长度max_len
            n_pad = maxlen - len(input_ids)
            input_ids.extend([0] * n_pad)
            segment_ids.extend([0] * n_pad)
            #print('n_pad:{}'.format(n_pad))
            
            # 不同句子中的mask长度不同,所以需要进行相同长度补齐
            if max_pred > n_pred:
                n_pad = max_pred - n_pred
                masked_tokens.extend([0] * n_pad)
                masked_pos.extend([0] * n_pad)
    
            if ((tokens_a_index + 1) == tokens_b_index) and (positive < (batch_size / 2)):
                batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True]) # isnext
                positive += 1
            elif ((tokens_a_index + 1) != tokens_b_index) and (negative < (batch_size / 2)):
                batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False]) # notnext
                negative += 1
        return batch
    

      

    '''
    构建的数据里除了input_ids,segment_ids外,还有masked_tokens,masked_pos被mask掉的字和其位置(用于bert训练时用),isNext是否为下一句。
    '''
    batch = build_data()
    input_ids, segment_ids, masked_tokens, masked_pos, isNext = zip(*batch)
    
    
    class MyDataset(Data.Dataset):
        def __init__(self, input_ids, segment_ids, masked_tokens, masked_pos, isNext):
            self.input_ids = input_ids
            self.segment_ids = segment_ids
            self.masked_tokens = masked_tokens
            self.masked_pos = masked_pos 
            self.isNext = isNext
            
        def __len__(self):
            return len(self.input_ids)
        
        def __getitem__(self, idx):
            return self.input_ids[idx], self.segment_ids[idx], self.masked_tokens[idx], self.masked_pos[idx], self.isNext[idx]
        
    loader = Data.DataLoader(MyDataset(input_ids, segment_ids, masked_tokens, masked_pos, isNext), batch_size, shuffle=True)
    

     当然这上面还缺了一个信息就是position embeddings,这个可以在bert模型中进行设置,如下:

      #位置信息
      pos = torch.arange(0, src_len).unsqueeze(0).repeat(batch_size, 1)


    参考:https://github.com/graykode
  • 相关阅读:
    298. Binary Tree Longest Consecutive Sequence
    117. Populating Next Right Pointers in Each Node II
    116. Populating Next Right Pointers in Each Node
    163. Missing Ranges
    336. Palindrome Pairs
    727. Minimum Window Subsequence
    211. Add and Search Word
    年底购物狂欢,移动支付安全不容忽视
    成为程序员前需要做的10件事
    全球首推iOS应用防破解技术!
  • 原文地址:https://www.cnblogs.com/little-horse/p/14622047.html
Copyright © 2011-2022 走看看