zoukankan      html  css  js  c++  java
  • Seq2seq 算法

    编码器解码器架构

    image-20211114095630923

    编码器-解码器架构就是构造一个编码器,通过编码器来获得解码器的初始state。这个在架构在很多情况下都可以使用,比如在Seq2Seq算法里。比如在机器翻译领域,通过编码器把要翻译的句子编码为一个初始状态,然后用解码器对这个状态进行解码,解码得到需要的翻译句子。

    class EncodeDecode(nn.Module):
        def __init__(self, Encode, Decode):
            super(EncodeDecode, self).__init__()
            self.encoder = Encode
            self.decoder = Decode
        def forward(self, EncodeX, DecodeX, *args):
            enco_outputs = self.encoder(EncodeX)
            deco_state = self.decoder.init_state(enco_outputs, *args)
            return self.decoder(DecodeX, deco_state)
    

    编码器-解码器架构说白了就是两段RNN + Linear

    我主要跟着做了Seq2seq 算法的实现。

    编码器

    class Encode(nn.Module):
        def __init__(self, vocab_size, embedding_size, hidim, num_layer, dropout=0) -> None:
            super(Encode, self).__init__()
            self.embedding = nn.Embedding(vocab_size, embedding_size)
            self.GRU = nn.GRU(embedding_size, hidim, num_layer, dropout=dropout)
            
        def forward(self, X, *args):
            '''X 的shape应该是(batchsize,时间步)'''
            embedding = self.embedding(X) # 会多出一个维度(时间步, batchsize, embedding_size)
            X = embedding.permute(1, 0, 2) #交换维度
            h, state = self.GRU(X) # 不考虑传入初始state
            return h, state
    

    这里需要注意的embedding,现在都2021年了不要再玩word2vec那一套的,用torch自带的embedding不好吗?

    解码器

    class Decode(nn.Module):
        def __init__(self, vocab_size, embedding_size, hidim, num_layer, dropout) -> None:
            super(Decode, self).__init__()
            self.embedding = nn.Embedding(vocab_size, embedding_size)
            self.GRU = nn.GRU(embedding_size+hidim, hidim, num_layer, dropout=dropout)
            self.Linear = nn.Linear(hidim, vocab_size) 
            
        def init_state(self, enco_ouputs, *args):
            return enco_ouputs[1] # 这就是encode输出的state,不过是因为enco_outputs 是一个元组
        
        def forward(self, X, state):
            '''X shape (batchsize, step)'''
            X = self.embedding(X).permute(1, 0, 2)
            context = state[-1].repeat(X.shape[0], 1, 1)
            X_context = torch.cat((X, context), 2)
            output, state = self.GRU(X_context, state)
            output = self.Linear(output)
            return output.permute(1, 0, 2), state
    
    

    解码器使用编码器输出的state作为初始的state,但是这里还做了一个操作,那就是把编码器输出的state的最后一层的那个state(鬼知道我在说什么,因为解码器是两层GRU架构,输出的state的shape是(2, batchsize, hidim) ) 就是只用最后一层的state,然后把它拼接到解码器输入的X的每一个词上,所以这就是self.GRU = nn.GRU(embedding_size+hidim, hidim, num_layer, dropout=dropout)embedding_size+hidim 的原因。

    掩码交叉熵

    掩码函数

    def sequence_mask(X, valid_len,value=0):
        '''using to get mask crossentropyLoss'''
        '''x tensor shape (batchsize, step, vocab_size)'''
        max_len = X.size(1)
        mask = torch.arange(max_len, device=X.device).reshape(1, -1) < valid_len.reshape(-1, 1)
        X[~mask] = value
        return X
    

    这里的mask用到了广播机制,因为mask是二维的,所以它作用于X的前两个维度。

    掩码交叉熵

       
    class Mask_CrossEntropyLoss(nn.CrossEntropyLoss):
        def forward(self, pred, labels, valid_len):
            '''X shape (batchsize, step, vocabsize)'''
            self.reduction = 'none'
            unweighted_loss = super(Mask_CrossEntropyLoss, self).forward(pred.permute(0, 2, 1), labels) #(batchsize, step)
            weights = torch.ones_like(labels)
            weights = sequence_mask(weights, valid_len)
            weighted_loss = (unweighted_loss * weights).mean(dim=1)
            return weighted_loss # (batchsize)
    

    这里的 unweighted_loss 的shape是 (batchsize, steps),然后把后面pading给用掩码变为0.

    训练

    def train_seq2seq(net, data_iter, num_epochs, lr, num_layer, tgt_vocab, device):
        net.to(device)
        loss = Mask_CrossEntropyLoss()
        optimizer = torch.optim.Adam(net.parameters(), lr=lr)
        net.train()
        for epoch in range(num_epochs):
            timer = d2l.Timer()
            metric = d2l.Accumulator(2) #损失和预测的总共词数
            for batch in data_iter:
                X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]
                bos = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0], device=device).reshape(-1, 1)
                dec_input = torch.concat((bos, Y[:, :-1]), dim=1) # 这里取了Y[:, :-1]是为了保证长度一致。
                y_hat, _ = net(X, dec_input, X_valid_len)
                l = loss(y_hat, Y, Y_valid_len)
                optimizer.zero_grad()
                l.sum().backward()
                nn.utils.clip_grad.clip_grad_norm_(net.parameters(), 1)
                optimizer.step()
                with torch.no_grad():
                    metric.add(l.sum(), sum(Y_valid_len))
            if (epoch + 1)% 50 == 0 or epoch == 0:
                print("epoch {}: loss {}".format(epoch + 1, metric[0]/metric[1]))
        print("loss {}, {} tokens/sec".format(metric[0]/metric[1], metric[1]/timer.stop()))
    

    基本逻辑就是用encode输出state,然后用decode进行预测,在计算交叉熵。训练是时候Decoder是有句子的,也就是label。但是真的进行翻译的时候,decoder是没有句子的,只有起始词<bos>。训练结果:

    image-20211114190124197

    预测

    预测部分,把我整晕了,真的莫名奇妙出现bug。。。直接用书本代码:

    #@save
    def predict_seq2seq(net, src_sentence, src_vocab, tgt_vocab, num_steps,
                        device, save_attention_weights=False):
        """序列到序列模型的预测"""
        src_tokens = src_vocab[src_sentence.lower().split(' ')] + [
            src_vocab['<eos>']]
        enc_valid_len = np.array([len(src_tokens)], ctx=device)
        src_tokens = d2l.truncate_pad(src_tokens, num_steps, src_vocab['<pad>'])
        # 添加批量轴
        enc_X = np.expand_dims(np.array(src_tokens, ctx=device), axis=0)
        enc_outputs = net.encoder(enc_X, enc_valid_len)
        dec_state = net.decoder.init_state(enc_outputs, enc_valid_len)
        # 添加批量轴
        dec_X = np.expand_dims(np.array([tgt_vocab['<bos>']], ctx=device), axis=0)
        output_seq, attention_weight_seq = [], []
        for _ in range(num_steps):
            Y, dec_state = net.decoder(dec_X, dec_state)
            # 我们使用具有预测最高可能性的词元,作为解码器在下一时间步的输入
            dec_X = Y.argmax(axis=2)
            pred = dec_X.squeeze(axis=0).astype('int32').item()
            # 保存注意力权重(稍后讨论)
            if save_attention_weights:
                attention_weight_seq.append(net.decoder.attention_weights)
            # 一旦序列结束词元被预测,输出序列的生成就完成了
            if pred == tgt_vocab['<eos>']:
                break
            output_seq.append(pred)
        return ' '.join(tgt_vocab.to_tokens(output_seq)), attention_weight_seq
    

    翻译准确度用BLEU分数来衡量:

    \[\exp \left(\min \left(0,1-\frac{l e n_{\text {label }}}{l e n_{\text {pred }}}\right)\right) \prod_{n=1}^{k} p_{n}^{1 / 2^{n}} \]

    def bleu(pred_seq, label_seq, k):  #@save
        """计算 BLEU"""
        pred_tokens, label_tokens = pred_seq.split(' '), label_seq.split(' ')
        len_pred, len_label = len(pred_tokens), len(label_tokens)
        score = math.exp(min(0, 1 - len_label / len_pred))
        for n in range(1, k + 1):
            num_matches, label_subs = 0, collections.defaultdict(int)
            for i in range(len_label - n + 1):
                label_subs[''.join(label_tokens[i: i + n])] += 1
            for i in range(len_pred - n + 1):
                if label_subs[''.join(pred_tokens[i: i + n])] > 0:
                    num_matches += 1
                    label_subs[''.join(pred_tokens[i: i + n])] -= 1
            score *= math.pow(num_matches / (len_pred - n + 1), math.pow(0.5, n))
        return score
    

    翻译:

    engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
    fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
    for eng, fra in zip(engs, fras):
        translation, attention_weight_seq = predict_seq2seq(
            net, eng, src_vocab, tgt_vocab, num_steps, device)
        print(f'{eng} => {translation}, bleu {bleu(translation, fra, k=2):.3f}')
    

    最后真的搞,我直接人晕了。。。

    下次有机会再好好看看seq2seq 吧,不过还是那句话,2021年谁还用seq2seq啊。。。

  • 相关阅读:
    ssh框架下 写简单的hql语句
    onclick事件 在使用模板填充情况下 向后台传递多值
    调用 sendResponseMsg 遇到的问题
    ERP项目有关时间的修改和查看的显示,去掉时分秒
    ERP中select的填充方法
    最简单的jQuery ajax请求
    ERP中默认申请人和申请部门
    list 按元素的某字段排序方法。作者:黄欣
    C# 对象、文件与二进制串(byte数组)之间的转换【转载】
    .net framework(4.6.2) 迁移 .net core(2.2) 总结
  • 原文地址:https://www.cnblogs.com/kalicener/p/15552858.html
Copyright © 2011-2022 走看看