zoukankan      html  css  js  c++  java
  • [课堂笔记][pytorch学习][7]Seq2Seq

    斯坦福公开课

    论文

    PyTorch代码(学习代码编写方式)

    更多关于Machine Translation

    • Beam Search
    • Pointer network 文本摘要
    • Copy Mechanism 文本摘要
    • Coverage Loss
    • ConvSeq2Seq
    • Transformer
    • Tensor2Tensor

    TODO

    • 建议同学尝试对中文进行分词

    NER

    部分代码

    读入中英文数据

    • 英文我们使用nltk的word tokenizer来分词,并且使用小写字母
    • 中文我们直接使用单个汉字作为基本单元
    import os
    import sys
    import math
    from collections import Counter
    import numpy as np
    import random
    
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    import nltk
    
    def load_data(in_file):
        """Read a tab-separated English/Chinese parallel corpus.

        Each usable line holds an English sentence and a Chinese sentence
        separated by a tab. The English side is lower-cased and tokenized with
        ``nltk.word_tokenize``; the Chinese side is split into single
        characters. Both token lists are wrapped in "BOS"/"EOS" markers.

        Lines with fewer than two tab-separated fields (for example blank
        lines) are skipped instead of raising ``IndexError``.

        Args:
            in_file: Path to the corpus file; it is read as UTF-8.

        Returns:
            A tuple ``(en, cn)`` of two parallel lists of token lists.
        """
        en = []
        cn = []
        # Explicit encoding: the file contains Chinese text and the platform
        # default encoding is not guaranteed to be UTF-8.
        with open(in_file, "r", encoding="utf-8") as f:
            for line in f:
                fields = line.strip().split("\t")
                if len(fields) < 2:
                    # Blank or malformed line: no sentence pair to extract.
                    continue
                en.append(["BOS"] + nltk.word_tokenize(fields[0].lower()) + ["EOS"])
                # Split the Chinese sentence into individual characters.
                cn.append(["BOS"] + list(fields[1]) + ["EOS"])
        return en, cn
    
    # Training split: tab-separated English/Chinese sentence pairs.
    train_file = "nmt/en-cn/train.txt"
    train_en, train_cn = load_data(train_file)
    # Development split, loaded the same way.
    dev_file = "nmt/en-cn/dev.txt"
    dev_en, dev_cn = load_data(dev_file)

    encoder/attention/decoder

    class Encoder(nn.Module):
        """Bidirectional GRU encoder over padded, variable-length sequences.

        ``forward`` returns the per-step encoder states and a single initial
        hidden state for the decoder, obtained by projecting the concatenated
        final forward/backward GRU states through ``fc`` and ``tanh``.
        """

        def __init__(self, vocab_size, embed_size, enc_hidden_size, dec_hidden_size, dropout=0.2):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, embed_size)
            self.rnn = nn.GRU(embed_size, enc_hidden_size, batch_first=True, bidirectional=True)
            self.dropout = nn.Dropout(dropout)
            # Maps the concatenated forward/backward state to the decoder size.
            self.fc = nn.Linear(enc_hidden_size * 2, dec_hidden_size)

        def forward(self, x, lengths):
            """Encode a batch of token ids.

            Args:
                x: Token ids indexed as (batch, time).
                lengths: Number of valid steps for each batch row.

            Returns:
                ``(out, hid)`` where ``out`` is (batch, longest length,
                2 * enc_hidden_size) in the original batch order and ``hid``
                is (1, batch, dec_hidden_size).
            """
            # Packing requires rows ordered by decreasing length.
            len_desc, order = lengths.sort(0, descending=True)
            order = order.long()
            tokens = x[order]
            emb = self.dropout(self.embed(tokens))

            packed_in = nn.utils.rnn.pack_padded_sequence(
                emb,
                len_desc.long().cpu().data.numpy(),
                batch_first=True,
            )
            packed_states, last = self.rnn(packed_in)
            states, _ = nn.utils.rnn.pad_packed_sequence(packed_states, batch_first=True)

            # Inverse permutation that restores the caller's batch order.
            restore = sorted_position = order.sort(0, descending=False)[1].long()
            out = states[restore].contiguous()
            last = last[:, sorted_position].contiguous()

            # The last two entries are the forward and backward final states.
            both_directions = torch.cat([last[-2], last[-1]], dim=1)
            hid = torch.tanh(self.fc(both_directions)).unsqueeze(0)

            return out, hid
    
    class Attention(nn.Module):
        """Bilinear attention of decoder states over encoder states.

        Encoder states are projected to the decoder size, scored against each
        decoder state with a batched dot product, masked, and normalized with
        softmax. The attended context is concatenated with the decoder state
        and mixed through ``linear_out`` + ``tanh``.
        """

        def __init__(self, enc_hidden_size, dec_hidden_size):
            super(Attention, self).__init__()

            self.enc_hidden_size = enc_hidden_size
            self.dec_hidden_size = dec_hidden_size

            # Projects encoder states (2 * enc_hidden_size) for scoring.
            self.linear_in = nn.Linear(enc_hidden_size * 2, dec_hidden_size, bias=False)
            # Mixes [attended context; decoder state] back to dec_hidden_size.
            self.linear_out = nn.Linear(enc_hidden_size * 2 + dec_hidden_size, dec_hidden_size)

        def forward(self, output, context, mask):
            """Attend from decoder states to encoder states.

            Args:
                output: (batch_size, output_len, dec_hidden_size) decoder states.
                context: (batch_size, context_len, 2 * enc_hidden_size)
                    encoder states.
                mask: (batch_size, output_len, context_len); nonzero/True
                    entries mark positions that must receive no attention.

            Returns:
                ``(output, attn)`` where ``output`` is
                (batch_size, output_len, dec_hidden_size) and ``attn`` is
                (batch_size, output_len, context_len) with rows summing to 1.
            """
            # (batch_size, context_len, dec_hidden_size)
            context_in = self.linear_in(context)

            # (batch_size, output_len, context_len)
            attn = torch.bmm(output, context_in.transpose(1, 2))

            # masked_fill is out-of-place, so its result must be kept; the
            # original discarded it and padding was never masked. A large
            # finite negative (rather than -inf) keeps fully-masked rows from
            # turning into NaN after softmax.
            attn = attn.masked_fill(mask.bool(), -1e6)

            attn = F.softmax(attn, dim=2)

            # (batch_size, output_len, 2 * enc_hidden_size)
            context = torch.bmm(attn, context)

            # (batch_size, output_len, 2 * enc_hidden_size + dec_hidden_size)
            output = torch.cat((context, output), dim=2)
            output = torch.tanh(self.linear_out(output))
            return output, attn
    
    
    class Decoder(nn.Module):
        """GRU decoder with attention over the encoder states.

        The target tokens are embedded, run through a unidirectional GRU, and
        each GRU state attends over the encoder context. The attended states
        are projected to vocabulary log-probabilities.
        """

        def __init__(self, vocab_size, embed_size, enc_hidden_size, dec_hidden_size, dropout=0.2):
            super(Decoder, self).__init__()
            self.embed = nn.Embedding(vocab_size, embed_size)
            self.attention = Attention(enc_hidden_size, dec_hidden_size)
            # The original referenced an undefined ``hidden_size`` here; the
            # GRU state must match the size the attention and ``out`` expect.
            self.rnn = nn.GRU(embed_size, dec_hidden_size, batch_first=True)
            self.out = nn.Linear(dec_hidden_size, vocab_size)
            self.dropout = nn.Dropout(dropout)

        def create_mask(self, x_len, y_len):
            """Build a padding mask for a pair of length vectors.

            Args:
                x_len: (batch,) lengths along the first masked axis.
                y_len: (batch,) lengths along the second masked axis.

            Returns:
                Bool tensor of shape (batch, max(x_len), max(y_len)) that is
                True wherever either index falls in padding.
            """
            device = x_len.device
            max_x_len = x_len.max()
            max_y_len = y_len.max()
            x_mask = torch.arange(max_x_len, device=device)[None, :] < x_len[:, None]
            y_mask = torch.arange(max_y_len, device=device)[None, :] < y_len[:, None]
            # Logical ops keep the mask boolean: ``1 - bool_tensor`` is
            # rejected by current PyTorch and byte masks are deprecated for
            # masked_fill.
            mask = ~(x_mask[:, :, None] & y_mask[:, None, :])
            return mask

        def forward(self, ctx, ctx_lengths, y, y_lengths, hid):
            """Decode target tokens against the encoder context.

            Args:
                ctx: Encoder states, (batch, context_len, 2 * enc_hidden_size).
                ctx_lengths: (batch,) valid lengths of ``ctx``.
                y: Target token ids, (batch, output_len).
                y_lengths: (batch,) valid lengths of ``y``.
                hid: Initial GRU state, (1, batch, dec_hidden_size).

            Returns:
                ``(output, hid, attn)``: log-probabilities over the
                vocabulary, the final GRU state in the original batch order,
                and the attention weights.
            """
            # Packing requires rows ordered by decreasing length.
            sorted_len, sorted_idx = y_lengths.sort(0, descending=True)
            y_sorted = y[sorted_idx.long()]
            hid = hid[:, sorted_idx.long()]

            y_sorted = self.dropout(self.embed(y_sorted))  # batch_size, output_length, embed_size

            packed_seq = nn.utils.rnn.pack_padded_sequence(y_sorted, sorted_len.long().cpu().data.numpy(), batch_first=True)
            out, hid = self.rnn(packed_seq, hid)
            unpacked, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
            # Undo the length sort so outputs line up with the caller's batch.
            _, original_idx = sorted_idx.sort(0, descending=False)
            output_seq = unpacked[original_idx.long()].contiguous()
            hid = hid[:, original_idx.long()].contiguous()

            # (batch, output_len, context_len): True where either side is padding.
            mask = self.create_mask(y_lengths, ctx_lengths)

            output, attn = self.attention(output_seq, ctx, mask)
            output = F.log_softmax(self.out(output), -1)

            return output, hid, attn

    seq2seq

    class Seq2Seq(nn.Module):
        """Glue module that chains an encoder and a decoder.

        ``forward`` runs the decoder on a full target sequence; ``translate``
        decodes greedily, one token at a time, for a fixed number of steps.
        """

        def __init__(self, encoder, decoder):
            super().__init__()
            self.encoder = encoder
            self.decoder = decoder

        def forward(self, x, x_lengths, y, y_lengths):
            """Encode ``x`` and decode ``y`` against it in one pass.

            Returns:
                ``(output, attn)``: the decoder's first and third return
                values; its final hidden state is dropped.
            """
            ctx, enc_state = self.encoder(x, x_lengths)
            scores, _, attn = self.decoder(
                ctx=ctx,
                ctx_lengths=x_lengths,
                y=y,
                y_lengths=y_lengths,
                hid=enc_state,
            )
            return scores, attn

        def translate(self, x, x_lengths, y, max_length=100):
            """Greedy decoding for exactly ``max_length`` steps.

            Args:
                x: Source batch passed to the encoder.
                x_lengths: Valid lengths of ``x``.
                y: First decoder input, one token per batch row.
                max_length: Number of decoding steps to run.

            Returns:
                ``(preds, attns)``: the chosen token ids and the per-step
                attention outputs, each concatenated along dimension 1.
            """
            ctx, state = self.encoder(x, x_lengths)
            n_rows = x.shape[0]
            chosen = []
            attn_steps = []
            prev = y
            for _ in range(max_length):
                # Every step feeds exactly one token per row.
                one_step = torch.ones(n_rows).long().to(prev.device)
                step_scores, state, step_attn = self.decoder(
                    ctx=ctx,
                    ctx_lengths=x_lengths,
                    y=prev,
                    y_lengths=one_step,
                    hid=state,
                )
                # Highest-scoring vocabulary index becomes the next input.
                _, best = step_scores.max(2)
                prev = best.view(n_rows, 1)
                chosen.append(prev)
                attn_steps.append(step_attn)
            return torch.cat(chosen, 1), torch.cat(attn_steps, 1)
  • 相关阅读:
    Java中的基本数据类型以及自增特性总结
    mysql菜鸟
    Typora使用教程
    net core下链路追踪skywalking安装和简单使用
    netcore5下ocelot网关简单使用
    netcore热插拔dll
    快速排序
    netcore5下js请求跨域
    SpringBoot接口防刷
    EL 表达式
  • 原文地址:https://www.cnblogs.com/nakkk/p/14988218.html
Copyright © 2011-2022 走看看