  • PyTorch seq2seq machine translation model (two versions: without attention and with attention)

    The corpus is small and the training time is short, so the model does not perform well; the following is a walkthrough of the process.

    Corpus link: https://pan.baidu.com/s/1wpP4t_GSyPAD6HTsIoGPZg
    Extraction code: jqq8

    Data format: one sentence pair per line, with the English sentence first, then a tab, then the Traditional Chinese translation (an illustrative line is shown below).
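    A line of the data file looks roughly like the following (illustrative only, reconstructed from the dev-set pair printed later; the raw file's exact casing and punctuation may differ):

    She put the magazine on the table.	她把雜誌放在桌上。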

    The code below was run on Google Colab.

    Imports:

    import os
    import sys
    import math
    from collections import Counter
    import numpy as np
    import random

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    import nltk
    nltk.download('punkt')

    1. Data preprocessing

    1.1 Load the English and Chinese data

    • English is tokenized with nltk's word tokenizer and lower-cased
    • Chinese is split into single characters as the basic unit
    def load_data(in_file):
        cn = []
        en = []
        with open(in_file, 'r') as f:
            for line in f:
                line = line.strip().split("\t")

                en.append(["BOS"] + nltk.word_tokenize(line[0].lower()) + ["EOS"])
                cn.append(["BOS"] + [c for c in line[1]] + ["EOS"])
        return en, cn

    train_file = "nmt/en-cn/train.txt"
    dev_file = "nmt/en-cn/dev.txt"
    train_en, train_cn = load_data(train_file)
    dev_en, dev_cn = load_data(dev_file)

    Inspect the returned data:

    print(dev_en[:2])
    print(dev_cn[:2])

    [['BOS', 'she', 'put', 'the', 'magazine', 'on', 'the', 'table', '.', 'EOS'], ['BOS', 'hey', ',', 'what', 'are', 'you', 'doing', 'here', '?', 'EOS']]

    [['BOS', '她', '把', '雜', '誌', '放', '在', '桌', '上', '。', 'EOS'], ['BOS', '嘿', ',', '你', '在', '這', '做', '什', '麼', '?', 'EOS']]

    1.2 Build the vocabularies

    UNK_IDX = 0
    PAD_IDX = 1
    def build_dict(sentences, max_words=50000):
        word_count = Counter()
        for sentence in sentences:
            for s in sentence:
                word_count[s] += 1
        ls = word_count.most_common(max_words)
        total_words = len(ls) + 2
        word_dict = {w[0]: index+2 for index, w in enumerate(ls)}
        word_dict["UNK"] = UNK_IDX
        word_dict["PAD"] = PAD_IDX
        return word_dict, total_words      # total_words is the vocabulary size, at most 50002

    en_dict, en_total_words = build_dict(train_en)
    cn_dict, cn_total_words = build_dict(train_cn)
    inv_en_dict = {v: k for k, v in en_dict.items()}    # English: index -> word
    inv_cn_dict = {v: k for k, v in cn_dict.items()}    # Chinese: index -> character

    1.3 Convert all words to indices

    With sort_by_len=True the sentences are sorted by length, so that sentences in the same batch end up with similar lengths.

    def encode(en_sentences, cn_sentences, en_dict, cn_dict, sort_by_len=True):

        out_en_sentences = [[en_dict.get(w, 0) for w in sent] for sent in en_sentences]
        out_cn_sentences = [[cn_dict.get(w, 0) for w in sent] for sent in cn_sentences]

        # sort sentences by word lengths
        def len_argsort(seq):
            return sorted(range(len(seq)), key=lambda x: len(seq[x]))

        # sort the Chinese and English sentences in the same order
        if sort_by_len:
            sorted_index = len_argsort(out_en_sentences)
            out_en_sentences = [out_en_sentences[i] for i in sorted_index]
            out_cn_sentences = [out_cn_sentences[i] for i in sorted_index]

        return out_en_sentences, out_cn_sentences

    train_en, train_cn = encode(train_en, train_cn, en_dict, cn_dict)
    dev_en, dev_cn = encode(dev_en, dev_cn, en_dict, cn_dict)

    Inspect the returned data:

    print(train_cn[2])
    print([inv_cn_dict[i] for i in train_cn[2]])
    print([inv_en_dict[i] for i in train_en[2]])

    [2, 982, 2028, 8, 4, 3]

    ['BOS', '祝', '贺', '你', '。', 'EOS']

    ['BOS', 'congratulations', '!', 'EOS']

    1.4 Split the sentences into batches

    def get_minibatches(n, minibatch_size, shuffle=True):   # n is the total number of sentences
        idx_list = np.arange(0, n, minibatch_size)   # start index of each minibatch: 0, minibatch_size, 2*minibatch_size, ...
        if shuffle:
            np.random.shuffle(idx_list)
        minibatches = []
        for idx in idx_list:
            minibatches.append(np.arange(idx, min(idx + minibatch_size, n)))
        return minibatches

    Check what the function above does:

    get_minibatches(100, 15)

    [array([60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74]),
     array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14]),
     array([75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89]),
     array([45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59]),
     array([30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44]),
     array([90, 91, 92, 93, 94, 95, 96, 97, 98, 99]),
     array([15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29])]

    def prepare_data(seqs):   # seqs: the sentences (a nested list of index lists) of one minibatch; here batch_size=64

        lengths = [len(seq) for seq in seqs]
        n_samples = len(seqs)
        max_len = np.max(lengths)  # length of the longest sentence in the batch

        x = np.zeros((n_samples, max_len)).astype('int32')
        x_lengths = np.array(lengths).astype("int32")
        for idx, seq in enumerate(seqs):
            x[idx, :lengths[idx]] = seq
        return x, x_lengths

    def gen_examples(en_sentences, cn_sentences, batch_size):
        minibatches = get_minibatches(len(en_sentences), batch_size)
        all_ex = []
        for minibatch in minibatches:
            mb_en_sentences = [en_sentences[t] for t in minibatch]
            mb_cn_sentences = [cn_sentences[t] for t in minibatch]
            mb_x, mb_x_len = prepare_data(mb_en_sentences)
            mb_y, mb_y_len = prepare_data(mb_cn_sentences)
            all_ex.append((mb_x, mb_x_len, mb_y, mb_y_len))
        return all_ex     # each tuple holds: English indices, English lengths, Chinese indices, Chinese lengths

    batch_size = 64
    train_data = gen_examples(train_en, train_cn, batch_size)
    dev_data = gen_examples(dev_en, dev_cn, batch_size)

    2. Encoder-Decoder model (without attention)

    2.1 Define the loss function

    # masked cross entropy loss
    class LanguageModelCriterion(nn.Module):
        def __init__(self):
            super(LanguageModelCriterion, self).__init__()

        def forward(self, input, target, mask):   # positions where mask == 0 are ignored
            # input: [batch_size, seq_len, vocab_size] -> [batch_size*seq_len, vocab_size]
            input = input.contiguous().view(-1, input.size(2))
            # target, mask: [batch_size, seq_len] -> [batch_size*seq_len, 1]
            target = target.contiguous().view(-1, 1)
            mask = mask.contiguous().view(-1, 1)
            output = -input.gather(1, target) * mask
            output = torch.sum(output) / torch.sum(mask)

            return output
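    As a quick sanity check, here is a small sketch of my own (not from the original post) showing that the mask removes padded target positions from the average:

    # toy batch: 2 target sequences, vocab size 5, padded to length 3
    crit = LanguageModelCriterion()
    log_probs = torch.log_softmax(torch.randn(2, 3, 5), dim=-1)   # [batch_size, seq_len, vocab_size]
    target = torch.tensor([[1, 2, 0], [3, 0, 0]])                 # positions after each real token are padding
    mask = torch.tensor([[1., 1., 0.], [1., 0., 0.]])             # 1 = real token, 0 = padding
    loss = crit(log_probs, target, mask)    # mean negative log-likelihood over the 3 real tokens only
    print(loss.item())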

    2.2 Encoder

    The encoder's job is to pass the input words through an embedding layer and a GRU, turning them into hidden states that later serve as the context vector;

    For an explanation of nn.utils.rnn.pack_padded_sequence and nn.utils.rnn.pad_packed_sequence, see: http://www.mamicode.com/info-detail-2493083.html
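    For intuition, here is a minimal sketch of my own (not from the post) showing what the two calls do on a toy batch of two sequences with true lengths 3 and 2:

    toy = torch.randn(2, 3, 4)             # [batch_size, max_len, embed_size]; the second row has one padded step
    toy_lengths = torch.tensor([3, 2])     # true lengths, already sorted in descending order
    packed = nn.utils.rnn.pack_padded_sequence(toy, toy_lengths, batch_first=True)
    toy_rnn = nn.GRU(4, 5, batch_first=True)
    packed_out, hid = toy_rnn(packed)      # hid is taken at each sequence's true last step, not at the padded step
    out, out_lengths = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
    print(out.shape, hid.shape)            # torch.Size([2, 3, 5]) torch.Size([1, 2, 5]); padded positions in out are zeros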

    class PlainEncoder(nn.Module):
        def __init__(self, vocab_size, hidden_size, dropout=0.2):       # assume embedding_size == hidden_size
            super(PlainEncoder, self).__init__()
            self.embed = nn.Embedding(vocab_size, hidden_size)
            self.rnn = nn.GRU(hidden_size, hidden_size, batch_first=True)
            self.dropout = nn.Dropout(dropout)

        def forward(self, x, lengths):   # lengths are needed so the last real hidden state can be used as the context vector
            sorted_len, sorted_idx = lengths.sort(0, descending=True)   # sort the sequences in the batch by length, descending
            x_sorted = x[sorted_idx.long()]
            embedded = self.dropout(self.embed(x_sorted))

            # sentences are padded to the same length (real lengths are shorter); pack_padded_sequence lets the RNN read only up to each real length
            packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, sorted_len.long().cpu().data.numpy(), batch_first=True)
            packed_out, hid = self.rnn(packed_embedded)
            out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)    # pad back to the full length

            _, original_idx = sorted_idx.sort(0, descending=False)                     # restore the original order
            out = out[original_idx.long()].contiguous()
            hid = hid[:, original_idx.long()].contiguous()

            return out, hid[[-1]]   # hid[[-1]] keeps the last layer's hidden state with shape [1, batch_size, hidden_size]

    2.3 Decoder

    The decoder predicts the next output word from the words translated so far and the context vector;

    class PlainDecoder(nn.Module):
        def __init__(self, vocab_size, hidden_size, dropout=0.2):
            super(PlainDecoder, self).__init__()
            self.embed = nn.Embedding(vocab_size, hidden_size)
            self.rnn = nn.GRU(hidden_size, hidden_size, batch_first=True)
            self.fc = nn.Linear(hidden_size, vocab_size)
            self.dropout = nn.Dropout(dropout)

        def forward(self, y, y_lengths, hid):    # similar to PlainEncoder.forward, except the initial hidden state is passed in rather than being zero
            sorted_len, sorted_idx = y_lengths.sort(0, descending=True)
            y_sorted = y[sorted_idx.long()]
            hid = hid[:, sorted_idx.long()]

            y_sorted = self.dropout(self.embed(y_sorted))             # [batch_size, y_lengths, embed_size=hidden_size]

            packed_seq = nn.utils.rnn.pack_padded_sequence(y_sorted, sorted_len.long().cpu().data.numpy(), batch_first=True)
            out, hid = self.rnn(packed_seq, hid)
            unpacked, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)

            _, original_idx = sorted_idx.sort(0, descending=False)
            output_seq = unpacked[original_idx.long()].contiguous()   # [batch_size, y_lengths, hidden_size]
            hid = hid[:, original_idx.long()].contiguous()            # [1, batch_size, hidden_size]

            output = F.log_softmax(self.fc(output_seq), -1)           # [batch_size, y_lengths, vocab_size]

            return output, hid

    2.4 Build the Seq2Seq model

    The Seq2Seq model chains the encoder and decoder together (no attention in this version);

    class PlainSeq2Seq(nn.Module):
        def __init__(self, encoder, decoder):
            super(PlainSeq2Seq, self).__init__()
            self.encoder = encoder
            self.decoder = decoder

        def forward(self, x, x_lengths, y, y_lengths):
            encoder_out, hid = self.encoder(x, x_lengths)
            output, hid = self.decoder(y, y_lengths, hid)
            return output, None

        def translate(self, x, x_lengths, y, max_length=10):
            encoder_out, hid = self.encoder(x, x_lengths)
            preds = []
            batch_size = x.shape[0]
            for i in range(max_length):
                # greedy decoding: feed the previous prediction back in as the next input
                output, hid = self.decoder(y=y, y_lengths=torch.ones(batch_size).long().to(y.device), hid=hid)
                y = output.max(2)[1].view(batch_size, 1)
                preds.append(y)

            return torch.cat(preds, 1), None


    3. Create the model

    Define the model, loss, and optimizer.

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dropout = 0.2
    hidden_size = 100
    encoder = PlainEncoder(vocab_size=en_total_words, hidden_size=hidden_size, dropout=dropout)
    decoder = PlainDecoder(vocab_size=cn_total_words, hidden_size=hidden_size, dropout=dropout)
    model = PlainSeq2Seq(encoder, decoder)
    model = model.to(device)
    loss_fn = LanguageModelCriterion().to(device)
    optimizer = torch.optim.Adam(model.parameters())

    4. Evaluate the model

    def evaluate(model, data):
        model.eval()
        total_num_words = total_loss = 0.
        with torch.no_grad():
            for it, (mb_x, mb_x_len, mb_y, mb_y_len) in enumerate(data):
                mb_x = torch.from_numpy(mb_x).to(device).long()
                mb_x_len = torch.from_numpy(mb_x_len).to(device).long()
                mb_input = torch.from_numpy(mb_y[:, :-1]).to(device).long()   # decoder input: target sentence without the last token
                mb_output = torch.from_numpy(mb_y[:, 1:]).to(device).long()   # decoder target: target sentence shifted left by one
                mb_y_len = torch.from_numpy(mb_y_len-1).to(device).long()
                mb_y_len[mb_y_len<=0] = 1

                mb_pred, attn = model(mb_x, mb_x_len, mb_input, mb_y_len)

                # mask: 1 for real target positions, 0 for padding
                mb_out_mask = torch.arange(mb_y_len.max().item(), device=device)[None, :] < mb_y_len[:, None]
                mb_out_mask = mb_out_mask.float()

                loss = loss_fn(mb_pred, mb_output, mb_out_mask)

                num_words = torch.sum(mb_y_len).item()
                total_loss += loss.item() * num_words
                total_num_words += num_words
        print("Evaluation loss", total_loss/total_num_words)

    5. Train the model

    def train(model, data, num_epochs=20):
        for epoch in range(num_epochs):
            model.train()
            total_num_words = total_loss = 0.
            for it, (mb_x, mb_x_len, mb_y, mb_y_len) in enumerate(data):
                mb_x = torch.from_numpy(mb_x).to(device).long()
                mb_x_len = torch.from_numpy(mb_x_len).to(device).long()
                mb_input = torch.from_numpy(mb_y[:, :-1]).to(device).long()
                mb_output = torch.from_numpy(mb_y[:, 1:]).to(device).long()
                mb_y_len = torch.from_numpy(mb_y_len-1).to(device).long()
                mb_y_len[mb_y_len<=0] = 1

                mb_pred, attn = model(mb_x, mb_x_len, mb_input, mb_y_len)

                mb_out_mask = torch.arange(mb_y_len.max().item(), device=device)[None, :] < mb_y_len[:, None]
                mb_out_mask = mb_out_mask.float()

                loss = loss_fn(mb_pred, mb_output, mb_out_mask)

                num_words = torch.sum(mb_y_len).item()
                total_loss += loss.item() * num_words
                total_num_words += num_words

                # update the model
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 5.)
                optimizer.step()

                if it % 100 == 0:
                    print("Epoch", epoch, "iteration", it, "loss", loss.item())

            print("Epoch", epoch, "Training loss", total_loss/total_num_words)
            if epoch % 5 == 0:
                evaluate(model, dev_data)

    Train for 100 epochs:

    train(model, train_data, num_epochs=100)

    Training output (the training loss keeps decreasing):

      1 Epoch 0 iteration 0 loss 8.084440231323242
      2 Epoch 0 iteration 100 loss 4.8448944091796875
      3 Epoch 0 iteration 200 loss 4.879772663116455
      4 Epoch 0 Training loss 5.477221919210141
      5 Evaluation loss 4.821030395389826
      6 Epoch 1 iteration 0 loss 4.69868278503418
      7 Epoch 1 iteration 100 loss 4.085171699523926
      8 Epoch 1 iteration 200 loss 4.312857151031494
      9 Epoch 1 Training loss 4.579521701350524
     10 Epoch 2 iteration 0 loss 4.193971633911133
     11 Epoch 2 iteration 100 loss 3.678673267364502
     12 Epoch 2 iteration 200 loss 4.019515514373779
     13 Epoch 2 Training loss 4.186071368925457
     14 Epoch 3 iteration 0 loss 3.8352835178375244
     15 Epoch 3 iteration 100 loss 3.3954527378082275
     16 Epoch 3 iteration 200 loss 3.774580240249634
     17 Epoch 3 Training loss 3.9222166424267986
     18 Epoch 4 iteration 0 loss 3.585063934326172
     19 Epoch 4 iteration 100 loss 3.215750217437744
     20 Epoch 4 iteration 200 loss 3.626997232437134
     21 Epoch 4 Training loss 3.722608096150466
     22 Epoch 5 iteration 0 loss 3.411375045776367
     23 Epoch 5 iteration 100 loss 3.0424859523773193
     24 Epoch 5 iteration 200 loss 3.492255926132202
     25 Epoch 5 Training loss 3.5699179079587195
     26 Evaluation loss 3.655821240952787
     27 Epoch 6 iteration 0 loss 3.273927927017212
     28 Epoch 6 iteration 100 loss 2.897022247314453
     29 Epoch 6 iteration 200 loss 3.355715036392212
     30 Epoch 6 Training loss 3.4411540739967426
     31 Epoch 7 iteration 0 loss 3.16508412361145
     32 Epoch 7 iteration 100 loss 2.7818763256073
     33 Epoch 7 iteration 200 loss 3.241000175476074
     34 Epoch 7 Training loss 3.330995073153501
     35 Epoch 8 iteration 0 loss 3.081458806991577
     36 Epoch 8 iteration 100 loss 2.692844867706299
     37 Epoch 8 iteration 200 loss 3.159105062484741
     38 Epoch 8 Training loss 3.237538761219645
     39 Epoch 9 iteration 0 loss 2.983361005783081
     40 Epoch 9 iteration 100 loss 2.5852301120758057
     41 Epoch 9 iteration 200 loss 3.076793670654297
     42 Epoch 9 Training loss 3.1542968146839754
     43 Epoch 10 iteration 0 loss 2.88155198097229
     44 Epoch 10 iteration 100 loss 2.504387617111206
     45 Epoch 10 iteration 200 loss 2.9708898067474365
     46 Epoch 10 Training loss 3.0766581801071924
     47 Evaluation loss 3.3804360915245204
     48 Epoch 11 iteration 0 loss 2.805739164352417
     49 Epoch 11 iteration 100 loss 2.417832612991333
     50 Epoch 11 iteration 200 loss 2.9001076221466064
     51 Epoch 11 Training loss 3.0072335865815747
     52 Epoch 12 iteration 0 loss 2.7389864921569824
     53 Epoch 12 iteration 100 loss 2.352132558822632
     54 Epoch 12 iteration 200 loss 2.864527702331543
     55 Epoch 12 Training loss 2.945309993148362
     56 Epoch 13 iteration 0 loss 2.6841001510620117
     57 Epoch 13 iteration 100 loss 2.2722346782684326
     58 Epoch 13 iteration 200 loss 2.8002915382385254
     59 Epoch 13 Training loss 2.8879525671218156
     60 Epoch 14 iteration 0 loss 2.641491651535034
     61 Epoch 14 iteration 100 loss 2.237807273864746
     62 Epoch 14 iteration 200 loss 2.7538034915924072
     63 Epoch 14 Training loss 2.833802188663957
     64 Epoch 15 iteration 0 loss 2.5613601207733154
     65 Epoch 15 iteration 100 loss 2.149299144744873
     66 Epoch 15 iteration 200 loss 2.671037435531616
     67 Epoch 15 Training loss 2.7850014679518598
     68 Evaluation loss 3.2569677577366516
     69 Epoch 16 iteration 0 loss 2.5330140590667725
     70 Epoch 16 iteration 100 loss 2.0988974571228027
     71 Epoch 16 iteration 200 loss 2.611022472381592
     72 Epoch 16 Training loss 2.7354116963192716
     73 Epoch 17 iteration 0 loss 2.485084295272827
     74 Epoch 17 iteration 100 loss 2.0532665252685547
     75 Epoch 17 iteration 200 loss 2.604226589202881
     76 Epoch 17 Training loss 2.6934350694497957
     77 Epoch 18 iteration 0 loss 2.4521820545196533
     78 Epoch 18 iteration 100 loss 2.0395381450653076
     79 Epoch 18 iteration 200 loss 2.5578808784484863
     80 Epoch 18 Training loss 2.651303096776386
     81 Epoch 19 iteration 0 loss 2.390338182449341
     82 Epoch 19 iteration 100 loss 1.9780246019363403
     83 Epoch 19 iteration 200 loss 2.5150232315063477
     84 Epoch 19 Training loss 2.611681331448251
     85 Epoch 20 iteration 0 loss 2.352649211883545
     86 Epoch 20 iteration 100 loss 1.9426053762435913
     87 Epoch 20 iteration 200 loss 2.4782586097717285
     88 Epoch 20 Training loss 2.5747013451744616
     89 Evaluation loss 3.194680030596711
     90 Epoch 21 iteration 0 loss 2.3205008506774902
     91 Epoch 21 iteration 100 loss 1.9143742322921753
     92 Epoch 21 iteration 200 loss 2.4607479572296143
     93 Epoch 21 Training loss 2.5404243457594116
     94 Epoch 22 iteration 0 loss 2.3100969791412354
     95 Epoch 22 iteration 100 loss 1.912932276725769
     96 Epoch 22 iteration 200 loss 2.4103682041168213
     97 Epoch 22 Training loss 2.507626390779296
     98 Epoch 23 iteration 0 loss 2.228956699371338
     99 Epoch 23 iteration 100 loss 1.8543353080749512
    100 Epoch 23 iteration 200 loss 2.3663489818573
    101 Epoch 23 Training loss 2.475231424650597
    102 Epoch 24 iteration 0 loss 2.199277639389038
    103 Epoch 24 iteration 100 loss 1.8272788524627686
    104 Epoch 24 iteration 200 loss 2.3518714904785156
    105 Epoch 24 Training loss 2.4439996520576863
    106 Epoch 25 iteration 0 loss 2.198460817337036
    107 Epoch 25 iteration 100 loss 1.7921738624572754
    108 Epoch 25 iteration 200 loss 2.3299384117126465
    109 Epoch 25 Training loss 2.416539151404694
    110 Evaluation loss 3.1583419660450347
    111 Epoch 26 iteration 0 loss 2.1647706031799316
    112 Epoch 26 iteration 100 loss 1.725657343864441
    113 Epoch 26 iteration 200 loss 2.268852710723877
    114 Epoch 26 Training loss 2.3919890312051444
    115 Epoch 27 iteration 0 loss 2.1400880813598633
    116 Epoch 27 iteration 100 loss 1.7474910020828247
    117 Epoch 27 iteration 200 loss 2.256742000579834
    118 Epoch 27 Training loss 2.3595162004913086
    119 Epoch 28 iteration 0 loss 2.0979115962982178
    120 Epoch 28 iteration 100 loss 1.7000322341918945
    121 Epoch 28 iteration 200 loss 2.2546005249023438
    122 Epoch 28 Training loss 2.3335356415568618
    123 Epoch 29 iteration 0 loss 2.1031572818756104
    124 Epoch 29 iteration 100 loss 1.6599613428115845
    125 Epoch 29 iteration 200 loss 2.2020833492279053
    126 Epoch 29 Training loss 2.311978717884133
    127 Epoch 30 iteration 0 loss 2.041980028152466
    128 Epoch 30 iteration 100 loss 1.6663353443145752
    129 Epoch 30 iteration 200 loss 2.1463098526000977
    130 Epoch 30 Training loss 2.2902015222655807
    131 Evaluation loss 3.133273747140961
    132 Epoch 31 iteration 0 loss 2.0045719146728516
    133 Epoch 31 iteration 100 loss 1.6515719890594482
    134 Epoch 31 iteration 200 loss 2.1130664348602295
    135 Epoch 31 Training loss 2.2633183437027657
    136 Epoch 32 iteration 0 loss 1.9948643445968628
    137 Epoch 32 iteration 100 loss 1.6262538433074951
    138 Epoch 32 iteration 200 loss 2.1329450607299805
    139 Epoch 32 Training loss 2.242057023454951
    140 Epoch 33 iteration 0 loss 1.9623773097991943
    141 Epoch 33 iteration 100 loss 1.6022558212280273
    142 Epoch 33 iteration 200 loss 2.092766523361206
    143 Epoch 33 Training loss 2.219300144243463
    144 Epoch 34 iteration 0 loss 1.929176688194275
    145 Epoch 34 iteration 100 loss 1.57985258102417
    146 Epoch 34 iteration 200 loss 2.067972183227539
    147 Epoch 34 Training loss 2.199957146669663
    148 Epoch 35 iteration 0 loss 1.9449653625488281
    149 Epoch 35 iteration 100 loss 1.5760831832885742
    150 Epoch 35 iteration 200 loss 2.056731939315796
    151 Epoch 35 Training loss 2.1790822226814464
    152 Evaluation loss 3.13363336627263
    153 Epoch 36 iteration 0 loss 1.8961074352264404
    154 Epoch 36 iteration 100 loss 1.5195672512054443
    155 Epoch 36 iteration 200 loss 2.0268213748931885
    156 Epoch 36 Training loss 2.160204240618562
    157 Epoch 37 iteration 0 loss 1.9172203540802002
    158 Epoch 37 iteration 100 loss 1.495902180671692
    159 Epoch 37 iteration 200 loss 1.9827772378921509
    160 Epoch 37 Training loss 2.139063811380212
    161 Epoch 38 iteration 0 loss 1.8988227844238281
    162 Epoch 38 iteration 100 loss 1.5224453210830688
    163 Epoch 38 iteration 200 loss 1.972291111946106
    164 Epoch 38 Training loss 2.1211086652629887
    165 Epoch 39 iteration 0 loss 1.8728121519088745
    166 Epoch 39 iteration 100 loss 1.4476994276046753
    167 Epoch 39 iteration 200 loss 1.9898269176483154
    168 Epoch 39 Training loss 2.1024907934743258
    169 Epoch 40 iteration 0 loss 1.8664008378982544
    170 Epoch 40 iteration 100 loss 1.4997611045837402
    171 Epoch 40 iteration 200 loss 1.9541966915130615
    172 Epoch 40 Training loss 2.086313187411815
    173 Evaluation loss 3.1282314096494708
    174 Epoch 41 iteration 0 loss 1.865237832069397
    175 Epoch 41 iteration 100 loss 1.4755399227142334
    176 Epoch 41 iteration 200 loss 1.9337103366851807
    177 Epoch 41 Training loss 2.068258631932244
    178 Epoch 42 iteration 0 loss 1.790804147720337
    179 Epoch 42 iteration 100 loss 1.4380069971084595
    180 Epoch 42 iteration 200 loss 1.9523491859436035
    181 Epoch 42 Training loss 2.0498001934027874
    182 Epoch 43 iteration 0 loss 1.7979768514633179
    183 Epoch 43 iteration 100 loss 1.436006784439087
    184 Epoch 43 iteration 200 loss 1.9101322889328003
    185 Epoch 43 Training loss 2.0354298580230195
    186 Epoch 44 iteration 0 loss 1.7717180252075195
    187 Epoch 44 iteration 100 loss 1.412601351737976
    188 Epoch 44 iteration 200 loss 1.8883790969848633
    189 Epoch 44 Training loss 2.0182710578663032
    190 Epoch 45 iteration 0 loss 1.7614871263504028
    191 Epoch 45 iteration 100 loss 1.3429900407791138
    192 Epoch 45 iteration 200 loss 1.862486720085144
    193 Epoch 45 Training loss 2.0034489605129595
    194 Evaluation loss 3.13050353642062
    195 Epoch 46 iteration 0 loss 1.753187656402588
    196 Epoch 46 iteration 100 loss 1.3810824155807495
    197 Epoch 46 iteration 200 loss 1.8526273965835571
    198 Epoch 46 Training loss 1.9899710891643612
    199 Epoch 47 iteration 0 loss 1.7567869424819946
    200 Epoch 47 iteration 100 loss 1.3430988788604736
    201 Epoch 47 iteration 200 loss 1.8135911226272583
    202 Epoch 47 Training loss 1.9723690433387957
    203 Epoch 48 iteration 0 loss 1.7263280153274536
    204 Epoch 48 iteration 100 loss 1.3430798053741455
    205 Epoch 48 iteration 200 loss 1.8229252099990845
    206 Epoch 48 Training loss 1.9580909331705005
    207 Epoch 49 iteration 0 loss 1.731834888458252
    208 Epoch 49 iteration 100 loss 1.325390100479126
    209 Epoch 49 iteration 200 loss 1.8075029850006104
    210 Epoch 49 Training loss 1.9418853706725143
    211 Epoch 50 iteration 0 loss 1.7218893766403198
    212 Epoch 50 iteration 100 loss 1.2710607051849365
    213 Epoch 50 iteration 200 loss 1.8196479082107544
    214 Epoch 50 Training loss 1.9300463292027463
    215 Evaluation loss 3.1402900424368902
    216 Epoch 51 iteration 0 loss 1.701721429824829
    217 Epoch 51 iteration 100 loss 1.2720820903778076
    218 Epoch 51 iteration 200 loss 1.7759710550308228
    219 Epoch 51 Training loss 1.9192517232508806
    220 Epoch 52 iteration 0 loss 1.7286512851715088
    221 Epoch 52 iteration 100 loss 1.2737478017807007
    222 Epoch 52 iteration 200 loss 1.7545547485351562
    223 Epoch 52 Training loss 1.906238278183267
    224 Epoch 53 iteration 0 loss 1.6672327518463135
    225 Epoch 53 iteration 100 loss 1.3138436079025269
    226 Epoch 53 iteration 200 loss 1.8045201301574707
    227 Epoch 53 Training loss 1.8922825534741075
    228 Epoch 54 iteration 0 loss 1.617557168006897
    229 Epoch 54 iteration 100 loss 1.22885262966156
    230 Epoch 54 iteration 200 loss 1.7750707864761353
    231 Epoch 54 Training loss 1.8807705430479014
    232 Epoch 55 iteration 0 loss 1.66348135471344
    233 Epoch 55 iteration 100 loss 1.2331219911575317
    234 Epoch 55 iteration 200 loss 1.7303975820541382
    235 Epoch 55 Training loss 1.867195544079556
    236 Evaluation loss 3.145431456349013
    237 Epoch 56 iteration 0 loss 1.6259342432022095
    238 Epoch 56 iteration 100 loss 1.2141388654708862
    239 Epoch 56 iteration 200 loss 1.6984847784042358
    240 Epoch 56 Training loss 1.8548133653506713
    241 Epoch 57 iteration 0 loss 1.605487585067749
    242 Epoch 57 iteration 100 loss 1.1920335292816162
    243 Epoch 57 iteration 200 loss 1.7253336906433105
    244 Epoch 57 Training loss 1.8387836396466541
    245 Epoch 58 iteration 0 loss 1.600136160850525
    246 Epoch 58 iteration 100 loss 1.2192472219467163
    247 Epoch 58 iteration 200 loss 1.6888371706008911
    248 Epoch 58 Training loss 1.83046734055076
    249 Epoch 59 iteration 0 loss 1.6042535305023193
    250 Epoch 59 iteration 100 loss 1.2362377643585205
    251 Epoch 59 iteration 200 loss 1.6654771566390991
    252 Epoch 59 Training loss 1.8226244935892273
    253 Epoch 60 iteration 0 loss 1.5602766275405884
    254 Epoch 60 iteration 100 loss 1.201045036315918
    255 Epoch 60 iteration 200 loss 1.6702684164047241
    256 Epoch 60 Training loss 1.8102721190615219
    257 Evaluation loss 3.154303393916162
    258 Epoch 61 iteration 0 loss 1.5679781436920166
    259 Epoch 61 iteration 100 loss 1.2105367183685303
    260 Epoch 61 iteration 200 loss 1.6650742292404175
    261 Epoch 61 Training loss 1.7970227477404426
    262 Epoch 62 iteration 0 loss 1.5734565258026123
    263 Epoch 62 iteration 100 loss 1.1602052450180054
    264 Epoch 62 iteration 200 loss 1.583187222480774
    265 Epoch 62 Training loss 1.787027303402099
    266 Epoch 63 iteration 0 loss 1.563283920288086
    267 Epoch 63 iteration 100 loss 1.1829460859298706
    268 Epoch 63 iteration 200 loss 1.6458944082260132
    269 Epoch 63 Training loss 1.7742324239103342
    270 Epoch 64 iteration 0 loss 1.5429617166519165
    271 Epoch 64 iteration 100 loss 1.1225509643554688
    272 Epoch 64 iteration 200 loss 1.6353931427001953
    273 Epoch 64 Training loss 1.7665018986396424
    274 Epoch 65 iteration 0 loss 1.5284583568572998
    275 Epoch 65 iteration 100 loss 1.1426113843917847
    276 Epoch 65 iteration 200 loss 1.6138485670089722
    277 Epoch 65 Training loss 1.7557591437816458
    278 Evaluation loss 3.166533922994568
    279 Epoch 66 iteration 0 loss 1.5184751749038696
    280 Epoch 66 iteration 100 loss 1.127056360244751
    281 Epoch 66 iteration 200 loss 1.611910343170166
    282 Epoch 66 Training loss 1.7446940747065838
    283 Epoch 67 iteration 0 loss 1.4880752563476562
    284 Epoch 67 iteration 100 loss 1.1075133085250854
    285 Epoch 67 iteration 200 loss 1.6138321161270142
    286 Epoch 67 Training loss 1.7374662356132202
    287 Epoch 68 iteration 0 loss 1.5260978937149048
    288 Epoch 68 iteration 100 loss 1.12235689163208
    289 Epoch 68 iteration 200 loss 1.6129950284957886
    290 Epoch 68 Training loss 1.7253250324901928
    291 Epoch 69 iteration 0 loss 1.5172449350357056
    292 Epoch 69 iteration 100 loss 1.1174883842468262
    293 Epoch 69 iteration 200 loss 1.551174283027649
    294 Epoch 69 Training loss 1.7166664929363027
    295 Epoch 70 iteration 0 loss 1.5006300210952759
    296 Epoch 70 iteration 100 loss 1.0905342102050781
    297 Epoch 70 iteration 200 loss 1.5446460247039795
    298 Epoch 70 Training loss 1.70989819337649
    299 Evaluation loss 3.1750113054724385
    300 Epoch 71 iteration 0 loss 1.4726097583770752
    301 Epoch 71 iteration 100 loss 1.086822509765625
    302 Epoch 71 iteration 200 loss 1.5575647354125977
    303 Epoch 71 Training loss 1.697000935158525
    304 Epoch 72 iteration 0 loss 1.449334979057312
    305 Epoch 72 iteration 100 loss 1.0667144060134888
    306 Epoch 72 iteration 200 loss 1.530726671218872
    307 Epoch 72 Training loss 1.6881878283419123
    308 Epoch 73 iteration 0 loss 1.4603246450424194
    309 Epoch 73 iteration 100 loss 1.0751914978027344
    310 Epoch 73 iteration 200 loss 1.5088605880737305
    311 Epoch 73 Training loss 1.6805761044806562
    312 Epoch 74 iteration 0 loss 1.4748084545135498
    313 Epoch 74 iteration 100 loss 1.0556395053863525
    314 Epoch 74 iteration 200 loss 1.5206905603408813
    315 Epoch 74 Training loss 1.6673887956853506
    316 Epoch 75 iteration 0 loss 1.454646348953247
    317 Epoch 75 iteration 100 loss 1.0396276712417603
    318 Epoch 75 iteration 200 loss 1.518398404121399
    319 Epoch 75 Training loss 1.6633919350661184
    320 Evaluation loss 3.189181657332237
    321 Epoch 76 iteration 0 loss 1.4616646766662598
    322 Epoch 76 iteration 100 loss 0.9838554859161377
    323 Epoch 76 iteration 200 loss 1.4613702297210693
    324 Epoch 76 Training loss 1.6526747506920867
    325 Epoch 77 iteration 0 loss 1.4646761417388916
    326 Epoch 77 iteration 100 loss 1.0383753776550293
    327 Epoch 77 iteration 200 loss 1.5081768035888672
    328 Epoch 77 Training loss 1.6462943129725018
    329 Epoch 78 iteration 0 loss 1.4008097648620605
    330 Epoch 78 iteration 100 loss 1.0147686004638672
    331 Epoch 78 iteration 200 loss 1.5017434358596802
    332 Epoch 78 Training loss 1.6352284007247493
    333 Epoch 79 iteration 0 loss 1.4189144372940063
    334 Epoch 79 iteration 100 loss 1.0126101970672607
    335 Epoch 79 iteration 200 loss 1.4195480346679688
    336 Epoch 79 Training loss 1.628015456811747
    337 Epoch 80 iteration 0 loss 1.4199804067611694
    338 Epoch 80 iteration 100 loss 1.0256879329681396
    339 Epoch 80 iteration 200 loss 1.4564563035964966
    340 Epoch 80 Training loss 1.6227562783981957
    341 Evaluation loss 3.2074876046135703
    342 Epoch 81 iteration 0 loss 1.431972622871399
    343 Epoch 81 iteration 100 loss 1.0110960006713867
    344 Epoch 81 iteration 200 loss 1.4414775371551514
    345 Epoch 81 Training loss 1.6157781071711008
    346 Epoch 82 iteration 0 loss 1.4158073663711548
    347 Epoch 82 iteration 100 loss 0.9702512621879578
    348 Epoch 82 iteration 200 loss 1.4209394454956055
    349 Epoch 82 Training loss 1.605166310639776
    350 Epoch 83 iteration 0 loss 1.3871146440505981
    351 Epoch 83 iteration 100 loss 1.0183656215667725
    352 Epoch 83 iteration 200 loss 1.4292359352111816
    353 Epoch 83 Training loss 1.5961119023327037
    354 Epoch 84 iteration 0 loss 1.3919366598129272
    355 Epoch 84 iteration 100 loss 0.9692129492759705
    356 Epoch 84 iteration 200 loss 1.4092985391616821
    357 Epoch 84 Training loss 1.5897755956223851
    358 Epoch 85 iteration 0 loss 1.355398416519165
    359 Epoch 85 iteration 100 loss 0.9916797280311584
    360 Epoch 85 iteration 200 loss 1.423561453819275
    361 Epoch 85 Training loss 1.5878568289810793
    362 Evaluation loss 3.2138472480503295
    363 Epoch 86 iteration 0 loss 1.351928472518921
    364 Epoch 86 iteration 100 loss 0.9997824430465698
    365 Epoch 86 iteration 200 loss 1.4049323797225952
    366 Epoch 86 Training loss 1.5719682346027806
    367 Epoch 87 iteration 0 loss 1.3508714437484741
    368 Epoch 87 iteration 100 loss 0.9411044716835022
    369 Epoch 87 iteration 200 loss 1.4019731283187866
    370 Epoch 87 Training loss 1.5641802139809575
    371 Epoch 88 iteration 0 loss 1.347946047782898
    372 Epoch 88 iteration 100 loss 0.9493017792701721
    373 Epoch 88 iteration 200 loss 1.3770906925201416
    374 Epoch 88 Training loss 1.5587840858982533
    375 Epoch 89 iteration 0 loss 1.320084571838379
    376 Epoch 89 iteration 100 loss 0.9223963022232056
    377 Epoch 89 iteration 200 loss 1.4065088033676147
    378 Epoch 89 Training loss 1.5548267858027334
    379 Epoch 90 iteration 0 loss 1.3534889221191406
    380 Epoch 90 iteration 100 loss 0.9281108975410461
    381 Epoch 90 iteration 200 loss 1.3821330070495605
    382 Epoch 90 Training loss 1.5474867314671616
    383 Evaluation loss 3.2276618163204667
    384 Epoch 91 iteration 0 loss 1.3667511940002441
    385 Epoch 91 iteration 100 loss 0.8797598481178284
    386 Epoch 91 iteration 200 loss 1.3776274919509888
    387 Epoch 91 Training loss 1.536482189982952
    388 Epoch 92 iteration 0 loss 1.3355433940887451
    389 Epoch 92 iteration 100 loss 0.9130176901817322
    390 Epoch 92 iteration 200 loss 1.3042923212051392
    391 Epoch 92 Training loss 1.5308507835779057
    392 Epoch 93 iteration 0 loss 1.2953367233276367
    393 Epoch 93 iteration 100 loss 0.9194003939628601
    394 Epoch 93 iteration 200 loss 1.3469970226287842
    395 Epoch 93 Training loss 1.519625581403501
    396 Epoch 94 iteration 0 loss 1.322600245475769
    397 Epoch 94 iteration 100 loss 0.9003701210021973
    398 Epoch 94 iteration 200 loss 1.3512846231460571
    399 Epoch 94 Training loss 1.5193673748787049
    400 Epoch 95 iteration 0 loss 1.2789180278778076
    401 Epoch 95 iteration 100 loss 0.9352515339851379
    402 Epoch 95 iteration 200 loss 1.3609877824783325
    403 Epoch 95 Training loss 1.5135782739054082
    404 Evaluation loss 3.2474015759319284
    405 Epoch 96 iteration 0 loss 1.3051612377166748
    406 Epoch 96 iteration 100 loss 0.8885603547096252
    407 Epoch 96 iteration 200 loss 1.3272497653961182
    408 Epoch 96 Training loss 1.5079536183100883
    409 Epoch 97 iteration 0 loss 1.2671339511871338
    410 Epoch 97 iteration 100 loss 0.8706735968589783
    411 Epoch 97 iteration 200 loss 1.305412769317627
    412 Epoch 97 Training loss 1.4974833326540824
    413 Epoch 98 iteration 0 loss 1.308292269706726
    414 Epoch 98 iteration 100 loss 0.9079441428184509
    415 Epoch 98 iteration 200 loss 1.2940715551376343
    416 Epoch 98 Training loss 1.4928753682563118
    417 Epoch 99 iteration 0 loss 1.276250958442688
    418 Epoch 99 iteration 100 loss 0.890657901763916
    419 Epoch 99 iteration 200 loss 1.3286609649658203
    420 Epoch 99 Training loss 1.4852960116094391

    6. Translation

    def translate_dev(i):
        en_sent = " ".join([inv_en_dict[w] for w in dev_en[i]])  # source English sentence
        print(en_sent)
        cn_sent = " ".join([inv_cn_dict[w] for w in dev_cn[i]])  # reference Chinese sentence
        print("".join(cn_sent))

        mb_x = torch.from_numpy(np.array(dev_en[i]).reshape(1, -1)).long().to(device)
        mb_x_len = torch.from_numpy(np.array([len(dev_en[i])])).long().to(device)
        bos = torch.Tensor([[cn_dict["BOS"]]]).long().to(device)

        translation, attn = model.translate(mb_x, mb_x_len, bos)
        translation = [inv_cn_dict[i] for i in translation.data.cpu().numpy().reshape(-1)]
        trans = []
        for word in translation:
            if word != "EOS":
                trans.append(word)
            else:
                break
        print("".join(trans))           # the model's Chinese translation

    for i in range(100, 120):
        translate_dev(i)
        print()

    The results are shown below (with a small corpus and short training time, the translation quality is not great):

     1 BOS you have nice skin . EOS
     2 BOS 你 的 皮 膚 真 好 。 EOS
     3 你有很多好。
     4 
     5 BOS you 're UNK correct . EOS
     6 BOS 你 部 分 正 确 。 EOS
     7 你是个好人。
     8 
     9 BOS everyone admired his courage . EOS
    10 BOS 每 個 人 都 佩 服 他 的 勇 氣 。 EOS
    11 他們的電話讓他們一個
    12 
    13 BOS what time is it ? EOS
    14 BOS 几 点 了 ? EOS
    15 它还是什么?
    16 
    17 BOS i 'm free tonight . EOS
    18 BOS 我 今 晚 有 空 。 EOS
    19 我今晚有空。
    20 
    21 BOS here is your book . EOS
    22 BOS 這 是 你 的 書 。 EOS
    23 你的書桌是舊。
    24 
    25 BOS they are at lunch . EOS
    26 BOS 他 们 在 吃 午 饭 。 EOS
    27 他们正在吃米饭。
    28 
    29 BOS this chair is UNK . EOS
    30 BOS 這 把 椅 子 很 UNK 。 EOS
    31 這是真的最好的人。
    32 
    33 BOS it 's pretty heavy . EOS
    34 BOS 它 真 重 。 EOS
    35 它是真的。
    36 
    37 BOS many attended his funeral . EOS
    38 BOS 很 多 人 都 参 加 了 他 的 葬 礼 。 EOS
    39 AI他的襪子。
    40 
    41 BOS training will be provided . EOS
    42 BOS 会 有 训 练 。 EOS
    43 人们的货品造成的。
    44 
    45 BOS someone is watching you . EOS
    46 BOS 有 人 在 看 著 你 。 EOS
    47 有人叫醒汤姆。
    48 
    49 BOS i slapped his face . EOS
    50 BOS 我 摑 了 他 的 臉 。 EOS
    51 我有他的兄弟。
    52 
    53 BOS i like UNK music . EOS
    54 BOS 我 喜 歡 流 行 音 樂 。 EOS
    55 我喜欢狗在家。
    56 
    57 BOS tom had no children . EOS
    58 BOS T o m 沒 有 孩 子 。 EOS
    59 汤姆不需要做什么。
    60 
    61 BOS please lock the door . EOS
    62 BOS 請 把 門 鎖 上 。 EOS
    63 請把門打開。
    64 
    65 BOS tom has calmed down . EOS
    66 BOS 汤 姆 冷 静 下 来 了 。 EOS
    67 汤姆睡着了。
    68 
    69 BOS please speak more loudly . EOS
    70 BOS 請 說 大 聲 一 點 兒 。 EOS
    71 請講慢一點。
    72 
    73 BOS keep next sunday free . EOS
    74 BOS 把 下 周 日 空 出 来 。 EOS
    75 下午可以轉下。
    76 
    77 BOS i made a mistake . EOS
    78 BOS 我 犯 了 一 個 錯 。 EOS
    79 我有些意生。

    7. Encoder-Decoder model (with attention)

    7.1 Encoder

    The encoder's job is to pass the input words through an embedding layer and a GRU, turning them into hidden states that serve as the context vectors for attention;

    class Encoder(nn.Module):
        def __init__(self, vocab_size, embed_size, enc_hidden_size, dec_hidden_size, dropout=0.2):
            super(Encoder, self).__init__()
            self.embed = nn.Embedding(vocab_size, embed_size)
            self.rnn = nn.GRU(embed_size, enc_hidden_size, batch_first=True, bidirectional=True)
            self.dropout = nn.Dropout(dropout)
            self.fc = nn.Linear(enc_hidden_size * 2, dec_hidden_size)

        def forward(self, x, lengths):
            sorted_len, sorted_idx = lengths.sort(0, descending=True)
            x_sorted = x[sorted_idx.long()]
            embedded = self.dropout(self.embed(x_sorted))

            packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, sorted_len.long().cpu().data.numpy(), batch_first=True)
            packed_out, hid = self.rnn(packed_embedded)
            out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
            _, original_idx = sorted_idx.sort(0, descending=False)
            out = out[original_idx.long()].contiguous()
            hid = hid[:, original_idx.long()].contiguous()

            hid = torch.cat([hid[-2], hid[-1]], dim=1)          # bidirectional GRU, so concatenate the final forward and backward hidden states
            hid = torch.tanh(self.fc(hid)).unsqueeze(0)

            return out, hid

    7.2 Luong attention

    In Luong's notation, $h_t$ is the GRU output at decoder step $t$ (the output below), and $\bar{h}_s$ is the encoder hidden state for source position $s$ (the context). In query-key-value terms, the query is $h_t$, and the keys and values are both $\bar{h}_s$;

    The attentional output is computed from the context vectors and the current decoder hidden states;
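    In formulas, the class below essentially implements the general form of Luong attention (here linear_in plays the role of $W_a$ and linear_out of $W_c$):

    $$\mathrm{score}(h_t, \bar{h}_s) = h_t^{\top} W_a \bar{h}_s, \qquad a_t(s) = \frac{\exp(\mathrm{score}(h_t, \bar{h}_s))}{\sum_{s'} \exp(\mathrm{score}(h_t, \bar{h}_{s'}))}$$

    $$c_t = \sum_s a_t(s)\,\bar{h}_s, \qquad \tilde{h}_t = \tanh\!\left(W_c\,[c_t;\, h_t]\right)$$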

    class Attention(nn.Module):
        def __init__(self, enc_hidden_size, dec_hidden_size):
            super(Attention, self).__init__()

            self.enc_hidden_size = enc_hidden_size
            self.dec_hidden_size = dec_hidden_size

            self.linear_in = nn.Linear(enc_hidden_size*2, dec_hidden_size, bias=False)        # linear transform of the encoder states
            self.linear_out = nn.Linear(enc_hidden_size*2 + dec_hidden_size, dec_hidden_size)

        def forward(self, output, context, mask):
            # output: [batch_size, output_len, dec_hidden_size]
            # context: [batch_size, context_len, 2*enc_hidden_size]  (*2 because the encoder GRU is bidirectional)

            batch_size = output.size(0)
            output_len = output.size(1)
            context_len = context.size(1)

            # context_in: [batch_size, context_len, dec_hidden_size]
            context_in = self.linear_in(context.view(batch_size*context_len, -1)).view(batch_size, context_len, -1)

            # context_in.transpose(1,2): [batch_size, dec_hidden_size, context_len]
            attn = torch.bmm(output, context_in.transpose(1,2))       # attn: [batch_size, output_len, context_len]

            attn = attn.masked_fill(mask, -1e6)                       # masked_fill is out-of-place, so the result must be assigned back

            attn = F.softmax(attn, dim=2)                             # attn: [batch_size, output_len, context_len]

            context = torch.bmm(attn, context)                        # context: [batch_size, output_len, 2*enc_hidden_size]

            output = torch.cat((context, output), dim=2)              # output: [batch_size, output_len, 2*enc_hidden_size + dec_hidden_size]

            output = output.view(batch_size*output_len, -1)
            output = torch.tanh(self.linear_out(output))
            output = output.view(batch_size, output_len, -1)

            return output, attn

    7.3 Decoder

    The decoder predicts the next output word from the words translated so far and the context vectors;

    class Decoder(nn.Module):
        def __init__(self, vocab_size, embed_size, enc_hidden_size, dec_hidden_size, dropout=0.2):
            super(Decoder, self).__init__()
            self.embed = nn.Embedding(vocab_size, embed_size)
            self.attention = Attention(enc_hidden_size, dec_hidden_size)
            self.rnn = nn.GRU(embed_size, dec_hidden_size, batch_first=True)
            self.out = nn.Linear(dec_hidden_size, vocab_size)
            self.dropout = nn.Dropout(dropout)

        def create_mask(self, x_len, y_len):  # mask: [batch_size, max_x_len, max_y_len], True where attention is NOT allowed
            device = x_len.device
            max_x_len = x_len.max()
            max_y_len = y_len.max()
            x_mask = torch.arange(max_x_len, device=device)[None, :] < x_len[:, None]
            y_mask = torch.arange(max_y_len, device=device)[None, :] < y_len[:, None]
            mask = ~(x_mask[:, :, None] & y_mask[:, None, :])
            return mask

        def forward(self, ctx, ctx_lengths, y, y_lengths, hid):
            sorted_len, sorted_idx = y_lengths.sort(0, descending=True)
            y_sorted = y[sorted_idx.long()]
            hid = hid[:, sorted_idx.long()]

            y_sorted = self.dropout(self.embed(y_sorted)) # batch_size, output_length, embed_size

            packed_seq = nn.utils.rnn.pack_padded_sequence(y_sorted, sorted_len.long().cpu().data.numpy(), batch_first=True)
            out, hid = self.rnn(packed_seq, hid)
            unpacked, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
            _, original_idx = sorted_idx.sort(0, descending=False)
            output_seq = unpacked[original_idx.long()].contiguous()
            hid = hid[:, original_idx.long()].contiguous()

            mask = self.create_mask(y_lengths, ctx_lengths)

            output, attn = self.attention(output_seq, ctx, mask)  # attend over the encoder context with the decoder outputs
            output = F.log_softmax(self.out(output), -1)

            return output, hid, attn

    7.4 Build the Seq2Seq model: chain the encoder, attention, and decoder together

    class Seq2Seq(nn.Module):
        def __init__(self, encoder, decoder):
            super(Seq2Seq, self).__init__()
            self.encoder = encoder
            self.decoder = decoder

        def forward(self, x, x_lengths, y, y_lengths):
            encoder_out, hid = self.encoder(x, x_lengths)
            output, hid, attn = self.decoder(ctx=encoder_out,
                        ctx_lengths=x_lengths,
                        y=y,
                        y_lengths=y_lengths,
                        hid=hid)
            return output, attn

        def translate(self, x, x_lengths, y, max_length=100):
            encoder_out, hid = self.encoder(x, x_lengths)
            preds = []
            batch_size = x.shape[0]
            attns = []
            for i in range(max_length):
                output, hid, attn = self.decoder(ctx=encoder_out,
                        ctx_lengths=x_lengths,
                        y=y,
                        y_lengths=torch.ones(batch_size).long().to(y.device),
                        hid=hid)
                y = output.max(2)[1].view(batch_size, 1)
                preds.append(y)
                attns.append(attn)
            return torch.cat(preds, 1), torch.cat(attns, 1)

    7.5 Create the attention model and call the train function defined above

    dropout = 0.2
    embed_size = hidden_size = 100
    encoder = Encoder(vocab_size=en_total_words, embed_size=embed_size, enc_hidden_size=hidden_size, dec_hidden_size=hidden_size, dropout=dropout)
    decoder = Decoder(vocab_size=cn_total_words, embed_size=embed_size, enc_hidden_size=hidden_size, dec_hidden_size=hidden_size, dropout=dropout)
    model = Seq2Seq(encoder, decoder)
    model = model.to(device)
    loss_fn = LanguageModelCriterion().to(device)
    optimizer = torch.optim.Adam(model.parameters())

    train(model, train_data, num_epochs=100)

    Training output:

      1 Epoch 0 iteration 0 loss 1.3026132583618164
      2 Epoch 0 iteration 100 loss 0.8847191333770752
      3 Epoch 0 iteration 200 loss 1.285671353340149
      4 Epoch 0 Training loss 1.4803871257447
      5 Evaluation loss 3.260634059314127
      6 Epoch 1 iteration 0 loss 1.2434465885162354
      7 Epoch 1 iteration 100 loss 0.8472312092781067
      8 Epoch 1 iteration 200 loss 1.282746434211731
      9 Epoch 1 Training loss 1.4731217075462495
     10 Epoch 2 iteration 0 loss 1.2593930959701538
     11 Epoch 2 iteration 100 loss 0.8484001159667969
     12 Epoch 2 iteration 200 loss 1.2862968444824219
     13 Epoch 2 Training loss 1.466728115795638
     14 Epoch 3 iteration 0 loss 1.2554501295089722
     15 Epoch 3 iteration 100 loss 0.9115875363349915
     16 Epoch 3 iteration 200 loss 1.2563236951828003
     17 Epoch 3 Training loss 1.4607854827943827
     18 Epoch 4 iteration 0 loss 1.217956304550171
     19 Epoch 4 iteration 100 loss 0.8641748428344727
     20 Epoch 4 iteration 200 loss 1.2998305559158325
     21 Epoch 4 Training loss 1.4587181747145395
     22 Epoch 5 iteration 0 loss 1.258739709854126
     23 Epoch 5 iteration 100 loss 0.8705984354019165
     24 Epoch 5 iteration 200 loss 1.2102816104888916
     25 Epoch 5 Training loss 1.4507371513452623
     26 Evaluation loss 3.266629208261664
     27 Epoch 6 iteration 0 loss 1.259811282157898
     28 Epoch 6 iteration 100 loss 0.8492067456245422
     29 Epoch 6 iteration 200 loss 1.3064922094345093
     30 Epoch 6 Training loss 1.4432446458560053
     31 Epoch 7 iteration 0 loss 1.2411160469055176
     32 Epoch 7 iteration 100 loss 0.8373231291770935
     33 Epoch 7 iteration 200 loss 1.2500189542770386
     34 Epoch 7 Training loss 1.436364381060567
     35 Epoch 8 iteration 0 loss 1.1868956089019775
     36 Epoch 8 iteration 100 loss 0.814584493637085
     37 Epoch 8 iteration 200 loss 1.2773609161376953
     38 Epoch 8 Training loss 1.4354508132900903
     39 Epoch 9 iteration 0 loss 1.2234464883804321
     40 Epoch 9 iteration 100 loss 0.797888457775116
     41 Epoch 9 iteration 200 loss 1.2435855865478516
     42 Epoch 9 Training loss 1.424914875345232
     43 Epoch 10 iteration 0 loss 1.2067270278930664
     44 Epoch 10 iteration 100 loss 0.8425077795982361
     45 Epoch 10 iteration 200 loss 1.2325958013534546
     46 Epoch 10 Training loss 1.4212906077384722
     47 Evaluation loss 3.2876189393276327
     48 Epoch 11 iteration 0 loss 1.221406102180481
     49 Epoch 11 iteration 100 loss 0.80806964635849
     50 Epoch 11 iteration 200 loss 1.3028448820114136
     51 Epoch 11 Training loss 1.4154276829998698
     52 Epoch 12 iteration 0 loss 1.1890984773635864
     53 Epoch 12 iteration 100 loss 0.827181875705719
     54 Epoch 12 iteration 200 loss 1.1675362586975098
     55 Epoch 12 Training loss 1.4132606964483012
     56 Epoch 13 iteration 0 loss 1.2002121210098267
     57 Epoch 13 iteration 100 loss 0.8232781291007996
     58 Epoch 13 iteration 200 loss 1.2605061531066895
     59 Epoch 13 Training loss 1.407515715564216
     60 Epoch 14 iteration 0 loss 1.1855664253234863
     61 Epoch 14 iteration 100 loss 0.8178666234016418
     62 Epoch 14 iteration 200 loss 1.2378345727920532
     63 Epoch 14 Training loss 1.3966619677770713
     64 Epoch 15 iteration 0 loss 1.1885008811950684
     65 Epoch 15 iteration 100 loss 0.7523401975631714
     66 Epoch 15 iteration 200 loss 1.1757400035858154
     67 Epoch 15 Training loss 1.3940533722612007
     68 Evaluation loss 3.3011410674061716
     69 Epoch 16 iteration 0 loss 1.185882806777954
     70 Epoch 16 iteration 100 loss 0.8129084706306458
     71 Epoch 16 iteration 200 loss 1.2022055387496948
     72 Epoch 16 Training loss 1.3908352825348185
     73 Epoch 17 iteration 0 loss 1.145820140838623
     74 Epoch 17 iteration 100 loss 0.7933529615402222
     75 Epoch 17 iteration 200 loss 1.1954973936080933
     76 Epoch 17 Training loss 1.3862186002415022
     77 Epoch 18 iteration 0 loss 1.1626101732254028
     78 Epoch 18 iteration 100 loss 0.8041335940361023
     79 Epoch 18 iteration 200 loss 1.1879560947418213
     80 Epoch 18 Training loss 1.3828558699833502
     81 Epoch 19 iteration 0 loss 1.1661605834960938
     82 Epoch 19 iteration 100 loss 0.7746578454971313
     83 Epoch 19 iteration 200 loss 1.167975902557373
     84 Epoch 19 Training loss 1.3737090146397222
     85 Epoch 20 iteration 0 loss 1.1992604732513428
     86 Epoch 20 iteration 100 loss 0.7750277519226074
     87 Epoch 20 iteration 200 loss 1.1533249616622925
     88 Epoch 20 Training loss 1.3699699581049805
     89 Evaluation loss 3.316624780553762
     90 Epoch 21 iteration 0 loss 1.182730793952942
     91 Epoch 21 iteration 100 loss 0.7664387822151184
     92 Epoch 21 iteration 200 loss 1.1734970808029175
     93 Epoch 21 Training loss 1.3634166858854262
     94 Epoch 22 iteration 0 loss 1.1587318181991577
     95 Epoch 22 iteration 100 loss 0.7660608291625977
     96 Epoch 22 iteration 200 loss 1.1832681894302368
     97 Epoch 22 Training loss 1.3601878647219552
     98 Epoch 23 iteration 0 loss 1.123557209968567
     99 Epoch 23 iteration 100 loss 0.7884796857833862
    100 Epoch 23 iteration 200 loss 1.131569266319275
    101 Epoch 23 Training loss 1.3543664767232568
    102 Epoch 24 iteration 0 loss 1.1566004753112793
    103 Epoch 24 iteration 100 loss 0.7894638180732727
    104 Epoch 24 iteration 200 loss 1.1293442249298096
    105 Epoch 24 Training loss 1.3513351050205646
    106 Epoch 25 iteration 0 loss 1.1237646341323853
    107 Epoch 25 iteration 100 loss 0.7442751526832581
    108 Epoch 25 iteration 200 loss 1.1396199464797974
    109 Epoch 25 Training loss 1.3436930389138495
    110 Evaluation loss 3.331576243054354
    111 Epoch 26 iteration 0 loss 1.1391510963439941
    112 Epoch 26 iteration 100 loss 0.7658866047859192
    113 Epoch 26 iteration 200 loss 1.130005121231079
    114 Epoch 26 Training loss 1.3387227258896204
    115 Epoch 27 iteration 0 loss 1.086417555809021
    116 Epoch 27 iteration 100 loss 0.7512990236282349
    117 Epoch 27 iteration 200 loss 1.1055928468704224
    118 Epoch 27 Training loss 1.3332018813254016
    119 Epoch 28 iteration 0 loss 1.1308163404464722
    120 Epoch 28 iteration 100 loss 0.7653459310531616
    121 Epoch 28 iteration 200 loss 1.1437530517578125
    122 Epoch 28 Training loss 1.3318316266073582
    123 Epoch 29 iteration 0 loss 1.1284910440444946
    124 Epoch 29 iteration 100 loss 0.7385256886482239
    125 Epoch 29 iteration 200 loss 1.076254963874817
    126 Epoch 29 Training loss 1.327704983812702
    127 Epoch 30 iteration 0 loss 1.1279666423797607
    128 Epoch 30 iteration 100 loss 0.7510428428649902
    129 Epoch 30 iteration 200 loss 1.10474693775177
    130 Epoch 30 Training loss 1.3247037412152105
    131 Evaluation loss 3.345638094832775
    132 Epoch 31 iteration 0 loss 1.1144018173217773
    133 Epoch 31 iteration 100 loss 0.7183322906494141
    134 Epoch 31 iteration 200 loss 1.1657849550247192
    135 Epoch 31 Training loss 1.3181022928511037
    136 Epoch 32 iteration 0 loss 1.1624877452850342
    137 Epoch 32 iteration 100 loss 0.6971022486686707
    138 Epoch 32 iteration 200 loss 1.1033793687820435
    139 Epoch 32 Training loss 1.313637083400949
    140 Epoch 33 iteration 0 loss 1.0961930751800537
    141 Epoch 33 iteration 100 loss 0.7509954571723938
    142 Epoch 33 iteration 200 loss 1.0901885032653809
    143 Epoch 33 Training loss 1.3105013603183797
    144 Epoch 34 iteration 0 loss 1.0936028957366943
    145 Epoch 34 iteration 100 loss 0.7300226092338562
    146 Epoch 34 iteration 200 loss 1.094140648841858
    147 Epoch 34 Training loss 1.3085180236466905
    148 Epoch 35 iteration 0 loss 1.1358038187026978
    149 Epoch 35 iteration 100 loss 0.6928472518920898
    150 Epoch 35 iteration 200 loss 1.1031907796859741
    151 Epoch 35 Training loss 1.2983291715098229
    152 Evaluation loss 3.3654267819449917
    153 Epoch 36 iteration 0 loss 1.0817443132400513
    154 Epoch 36 iteration 100 loss 0.7034777998924255
    155 Epoch 36 iteration 200 loss 1.1244701147079468
    156 Epoch 36 Training loss 1.294685624884497
    157 Epoch 37 iteration 0 loss 1.0067986249923706
    158 Epoch 37 iteration 100 loss 0.6711763739585876
    159 Epoch 37 iteration 200 loss 1.0877138376235962
    160 Epoch 37 Training loss 1.2908666166705178
    161 Epoch 38 iteration 0 loss 1.0796058177947998
    162 Epoch 38 iteration 100 loss 0.6984289288520813
    163 Epoch 38 iteration 200 loss 1.0992212295532227
    164 Epoch 38 Training loss 1.289693898836594
    165 Epoch 39 iteration 0 loss 1.1193760633468628
    166 Epoch 39 iteration 100 loss 0.7441080212593079
    167 Epoch 39 iteration 200 loss 1.0557031631469727
    168 Epoch 39 Training loss 1.287817969907393
    169 Epoch 40 iteration 0 loss 1.0878312587738037
    170 Epoch 40 iteration 100 loss 0.7390894889831543
    171 Epoch 40 iteration 200 loss 1.0931909084320068
    172 Epoch 40 Training loss 1.281862247889987
    173 Evaluation loss 3.3775776429435034
    174 Epoch 41 iteration 0 loss 1.1135987043380737
    175 Epoch 41 iteration 100 loss 0.6786257028579712
    176 Epoch 41 iteration 200 loss 1.056801676750183
    177 Epoch 41 Training loss 1.2759448500226234
    178 Epoch 42 iteration 0 loss 1.0649049282073975
    179 Epoch 42 iteration 100 loss 0.6774815320968628
    180 Epoch 42 iteration 200 loss 1.0807018280029297
    181 Epoch 42 Training loss 1.2713608683004023
    182 Epoch 43 iteration 0 loss 1.0711919069290161
    183 Epoch 43 iteration 100 loss 0.6655244827270508
    184 Epoch 43 iteration 200 loss 1.0616692304611206
    185 Epoch 43 Training loss 1.2693709204800718
    186 Epoch 44 iteration 0 loss 1.0423146486282349
    187 Epoch 44 iteration 100 loss 0.7055337429046631
    188 Epoch 44 iteration 200 loss 1.0746649503707886
    189 Epoch 44 Training loss 1.2649514066760854
    190 Epoch 45 iteration 0 loss 1.0937353372573853
    191 Epoch 45 iteration 100 loss 0.6939021348953247
    192 Epoch 45 iteration 200 loss 1.1060905456542969
    193 Epoch 45 Training loss 1.2591645727085945
    194 Evaluation loss 3.393269126438251
    195 Epoch 46 iteration 0 loss 1.1005926132202148
    196 Epoch 46 iteration 100 loss 0.6948174238204956
    197 Epoch 46 iteration 200 loss 1.0675958395004272
    198 Epoch 46 Training loss 1.2555609077507983
    199 Epoch 47 iteration 0 loss 1.0566778182983398
    200 Epoch 47 iteration 100 loss 0.6904436349868774
    201 Epoch 47 iteration 200 loss 1.0723766088485718
    202 Epoch 47 Training loss 1.2552127211907091
    203 Epoch 48 iteration 0 loss 1.0497757196426392
    204 Epoch 48 iteration 100 loss 0.6351101398468018
    205 Epoch 48 iteration 200 loss 1.0661102533340454
    206 Epoch 48 Training loss 1.2479233140629313
    207 Epoch 49 iteration 0 loss 1.0470858812332153
    208 Epoch 49 iteration 100 loss 0.6707669496536255
    209 Epoch 49 iteration 200 loss 1.063056230545044
    210 Epoch 49 Training loss 1.2453716254928995
    211 Epoch 50 iteration 0 loss 1.0854698419570923
    212 Epoch 50 iteration 100 loss 0.6165581345558167
    213 Epoch 50 iteration 200 loss 1.0804699659347534
    214 Epoch 50 Training loss 1.243395479327141
    215 Evaluation loss 3.4102849531750494
    216 Epoch 51 iteration 0 loss 1.0279858112335205
    217 Epoch 51 iteration 100 loss 0.6448107957839966
    218 Epoch 51 iteration 200 loss 1.0390673875808716
    219 Epoch 51 Training loss 1.2358123475125082
    220 Epoch 52 iteration 0 loss 1.0429105758666992
    221 Epoch 52 iteration 100 loss 0.7124451994895935
    222 Epoch 52 iteration 200 loss 1.061672329902649
    223 Epoch 52 Training loss 1.233326693902576
    224 Epoch 53 iteration 0 loss 1.0357102155685425
    225 Epoch 53 iteration 100 loss 0.6381393074989319
    226 Epoch 53 iteration 200 loss 1.0036036968231201
    227 Epoch 53 Training loss 1.2297246983027847
    228 Epoch 54 iteration 0 loss 1.0590764284133911
    229 Epoch 54 iteration 100 loss 0.6603832840919495
    230 Epoch 54 iteration 200 loss 1.0215944051742554
    231 Epoch 54 Training loss 1.227340017322883
    232 Epoch 55 iteration 0 loss 1.0460106134414673
    233 Epoch 55 iteration 100 loss 0.67122882604599
    234 Epoch 55 iteration 200 loss 1.0344772338867188
    235 Epoch 55 Training loss 1.2244369935263697
    236 Evaluation loss 3.4193579365732183
    237 Epoch 56 iteration 0 loss 1.032409429550171
    238 Epoch 56 iteration 100 loss 0.6183319091796875
    239 Epoch 56 iteration 200 loss 0.9782896637916565
    240 Epoch 56 Training loss 1.2178285635372972
    241 Epoch 57 iteration 0 loss 1.0382548570632935
    242 Epoch 57 iteration 100 loss 0.6902874708175659
    243 Epoch 57 iteration 200 loss 1.016508936882019
    244 Epoch 57 Training loss 1.214647328633978
    245 Epoch 58 iteration 0 loss 1.0595533847808838
    246 Epoch 58 iteration 100 loss 0.6885846853256226
    247 Epoch 58 iteration 200 loss 1.0221766233444214
    248 Epoch 58 Training loss 1.2100419675457097
    249 Epoch 59 iteration 0 loss 1.014621615409851
    250 Epoch 59 iteration 100 loss 0.602800190448761
    251 Epoch 59 iteration 200 loss 1.037442684173584
    252 Epoch 59 Training loss 1.2131489746903632
    253 Epoch 60 iteration 0 loss 1.0217640399932861
    254 Epoch 60 iteration 100 loss 0.6246439814567566
    255 Epoch 60 iteration 200 loss 1.00297212600708
    256 Epoch 60 Training loss 1.204652290841725
    257 Evaluation loss 3.4292290158399075
    258 Epoch 61 iteration 0 loss 0.9992070198059082
    259 Epoch 61 iteration 100 loss 0.645142138004303
    260 Epoch 61 iteration 200 loss 0.9961024522781372
    261 Epoch 61 Training loss 1.2066241352823563
    262 Epoch 62 iteration 0 loss 0.9980950951576233
    263 Epoch 62 iteration 100 loss 0.6504135131835938
    264 Epoch 62 iteration 200 loss 1.000308632850647
    265 Epoch 62 Training loss 1.1984729866171178
    266 Epoch 63 iteration 0 loss 0.9869410395622253
    267 Epoch 63 iteration 100 loss 0.6618863344192505
    268 Epoch 63 iteration 200 loss 0.981200635433197
    269 Epoch 63 Training loss 1.1967191445048035
    270 Epoch 64 iteration 0 loss 0.9695953130722046
    271 Epoch 64 iteration 100 loss 0.6359274387359619
    272 Epoch 64 iteration 200 loss 0.9904515743255615
    273 Epoch 64 Training loss 1.194521029171779
    274 Epoch 65 iteration 0 loss 0.9505796432495117
    275 Epoch 65 iteration 100 loss 0.6068794131278992
    276 Epoch 65 iteration 200 loss 0.980348527431488
    277 Epoch 65 Training loss 1.189270519480765
    278 Evaluation loss 3.454153442993674
    279 Epoch 66 iteration 0 loss 1.0304545164108276
    280 Epoch 66 iteration 100 loss 0.6792566776275635
    281 Epoch 66 iteration 200 loss 0.9789241552352905
    282 Epoch 66 Training loss 1.1869953296382767
    283 Epoch 67 iteration 0 loss 0.957666277885437
    284 Epoch 67 iteration 100 loss 0.584879994392395
    285 Epoch 67 iteration 200 loss 1.0174148082733154
    286 Epoch 67 Training loss 1.184179561090835
    287 Epoch 68 iteration 0 loss 1.043166995048523
    288 Epoch 68 iteration 100 loss 0.6168758869171143
    289 Epoch 68 iteration 200 loss 1.0030053853988647
    290 Epoch 68 Training loss 1.1824355462851552
    291 Epoch 69 iteration 0 loss 1.0165300369262695
    292 Epoch 69 iteration 100 loss 0.6542645692825317
    293 Epoch 69 iteration 200 loss 1.0191236734390259
    294 Epoch 69 Training loss 1.176021675397731
    295 Epoch 70 iteration 0 loss 0.9590736031532288
    296 Epoch 70 iteration 100 loss 0.6157773733139038
    297 Epoch 70 iteration 200 loss 1.0451829433441162
    298 Epoch 70 Training loss 1.1732503092442255
    299 Evaluation loss 3.4715566423642277
    300 Epoch 71 iteration 0 loss 0.971733570098877
    301 Epoch 71 iteration 100 loss 0.5589802265167236
    302 Epoch 71 iteration 200 loss 1.0018212795257568
    303 Epoch 71 Training loss 1.1694891346833023
    304 Epoch 72 iteration 0 loss 1.0042874813079834
    305 Epoch 72 iteration 100 loss 0.6543828248977661
    306 Epoch 72 iteration 200 loss 0.968835175037384
    307 Epoch 72 Training loss 1.1667191714442264
    308 Epoch 73 iteration 0 loss 0.9512341022491455
    309 Epoch 73 iteration 100 loss 0.5809782147407532
    310 Epoch 73 iteration 200 loss 0.9460022449493408
    311 Epoch 73 Training loss 1.165780372824424
    312 Epoch 74 iteration 0 loss 0.9838390946388245
    313 Epoch 74 iteration 100 loss 0.6115572452545166
    314 Epoch 74 iteration 200 loss 0.9821975827217102
    315 Epoch 74 Training loss 1.1619031632185661
    316 Epoch 75 iteration 0 loss 0.9615085124969482
    317 Epoch 75 iteration 100 loss 0.5715279579162598
    318 Epoch 75 iteration 200 loss 0.9673617482185364
    319 Epoch 75 Training loss 1.1592393025041507
    320 Evaluation loss 3.4810480487503015
    321 Epoch 76 iteration 0 loss 0.9920525550842285
    322 Epoch 76 iteration 100 loss 0.6243174076080322
    323 Epoch 76 iteration 200 loss 0.9598985910415649
    324 Epoch 76 Training loss 1.1506768550866349
    325 Epoch 77 iteration 0 loss 0.9717826843261719
    326 Epoch 77 iteration 100 loss 0.5903583765029907
    327 Epoch 77 iteration 200 loss 0.9472079873085022
    328 Epoch 77 Training loss 1.151601228059984
    329 Epoch 78 iteration 0 loss 0.9331899881362915
    330 Epoch 78 iteration 100 loss 0.6189018487930298
    331 Epoch 78 iteration 200 loss 0.9951513409614563
    332 Epoch 78 Training loss 1.1474281610158772
    333 Epoch 79 iteration 0 loss 0.9012037515640259
    334 Epoch 79 iteration 100 loss 0.5837778449058533
    335 Epoch 79 iteration 200 loss 0.9066386818885803
    336 Epoch 79 Training loss 1.142656700489289
    337 Epoch 80 iteration 0 loss 0.9931736588478088
    338 Epoch 80 iteration 100 loss 0.5927265882492065
    339 Epoch 80 iteration 200 loss 0.938447892665863
    340 Epoch 80 Training loss 1.1434640075302192
    341 Evaluation loss 3.491394963812294
    342 Epoch 81 iteration 0 loss 0.9227023720741272
    343 Epoch 81 iteration 100 loss 0.5467157363891602
    344 Epoch 81 iteration 200 loss 0.9126712083816528
    345 Epoch 81 Training loss 1.1427882346320761
    346 Epoch 82 iteration 0 loss 0.9733406901359558
    347 Epoch 82 iteration 100 loss 0.564643144607544
    348 Epoch 82 iteration 200 loss 0.9918593764305115
    349 Epoch 82 Training loss 1.1362837826371996
    350 Epoch 83 iteration 0 loss 0.9489978551864624
    351 Epoch 83 iteration 100 loss 0.5791521668434143
    352 Epoch 83 iteration 200 loss 0.9270768165588379
    353 Epoch 83 Training loss 1.136011173156159
    354 Epoch 84 iteration 0 loss 0.9410436749458313
    355 Epoch 84 iteration 100 loss 0.5409624576568604
    356 Epoch 84 iteration 200 loss 0.8918321132659912
    357 Epoch 84 Training loss 1.128506034776273
    358 Epoch 85 iteration 0 loss 0.9554007053375244
    359 Epoch 85 iteration 100 loss 0.571331799030304
    360 Epoch 85 iteration 200 loss 0.9672144055366516
    361 Epoch 85 Training loss 1.1277133646535586
    362 Evaluation loss 3.5075310684848158
    363 Epoch 86 iteration 0 loss 0.9104467630386353
    364 Epoch 86 iteration 100 loss 0.5656437277793884
    365 Epoch 86 iteration 200 loss 0.9324206113815308
    366 Epoch 86 Training loss 1.126375005188875
    367 Epoch 87 iteration 0 loss 0.9339620471000671
    368 Epoch 87 iteration 100 loss 0.5636867880821228
    369 Epoch 87 iteration 200 loss 0.8825109601020813
    370 Epoch 87 Training loss 1.1222316938253494
    371 Epoch 88 iteration 0 loss 0.904504120349884
    372 Epoch 88 iteration 100 loss 0.5706378221511841
    373 Epoch 88 iteration 200 loss 0.9415532350540161
    374 Epoch 88 Training loss 1.1215731092872845
    375 Epoch 89 iteration 0 loss 0.9489354491233826
    376 Epoch 89 iteration 100 loss 0.6389216184616089
    377 Epoch 89 iteration 200 loss 0.8783397078514099
    378 Epoch 89 Training loss 1.1199876689692458
    379 Epoch 90 iteration 0 loss 0.909376323223114
    380 Epoch 90 iteration 100 loss 0.6190019249916077
    381 Epoch 90 iteration 200 loss 0.9191233515739441
    382 Epoch 90 Training loss 1.1181392741798546
    383 Evaluation loss 3.508678977926201
    384 Epoch 91 iteration 0 loss 0.9080389738082886
    385 Epoch 91 iteration 100 loss 0.5580074191093445
    386 Epoch 91 iteration 200 loss 0.9494779706001282
    387 Epoch 91 Training loss 1.1147855064570311
    388 Epoch 92 iteration 0 loss 0.900802731513977
    389 Epoch 92 iteration 100 loss 0.573580801486969
    390 Epoch 92 iteration 200 loss 0.9199456572532654
    391 Epoch 92 Training loss 1.1107969536786537
    392 Epoch 93 iteration 0 loss 0.9345868229866028
    393 Epoch 93 iteration 100 loss 0.5590959787368774
    394 Epoch 93 iteration 200 loss 0.90354984998703
    395 Epoch 93 Training loss 1.105984925602608
    396 Epoch 94 iteration 0 loss 0.9008861780166626
    397 Epoch 94 iteration 100 loss 0.5503742098808289
    398 Epoch 94 iteration 200 loss 0.8791723251342773
    399 Epoch 94 Training loss 1.1053885063813342
    400 Epoch 95 iteration 0 loss 0.899246096611023
    401 Epoch 95 iteration 100 loss 0.6236768364906311
    402 Epoch 95 iteration 200 loss 0.8661567568778992
    403 Epoch 95 Training loss 1.0993307278503
    404 Evaluation loss 3.5332032706941585
    405 Epoch 96 iteration 0 loss 0.8837733864784241
    406 Epoch 96 iteration 100 loss 0.5473974943161011
    407 Epoch 96 iteration 200 loss 0.9025910496711731
    408 Epoch 96 Training loss 1.0998253373283113
    409 Epoch 97 iteration 0 loss 0.922965407371521
    410 Epoch 97 iteration 100 loss 0.5556969046592712
    411 Epoch 97 iteration 200 loss 0.9027858972549438
    412 Epoch 97 Training loss 1.096842199480861
    413 Epoch 98 iteration 0 loss 0.8947715759277344
    414 Epoch 98 iteration 100 loss 0.5312948822975159
    415 Epoch 98 iteration 200 loss 0.9379984736442566
    416 Epoch 98 Training loss 1.0949072217540066
    417 Epoch 99 iteration 0 loss 0.8829227685928345
    418 Epoch 99 iteration 100 loss 0.5451477766036987
    419 Epoch 99 iteration 200 loss 0.8783729672431946
    420 Epoch 99 Training loss 1.092216717956385
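    The log shows the training loss falling steadily (from about 1.35 at epoch 24 to about 1.09 at epoch 99) while the evaluation loss keeps creeping upward after roughly epoch 25 (from about 3.33 to about 3.53), the usual sign of overfitting on a small corpus. Plotting the two curves makes this easy to see. The sketch below simply hard-codes a few epoch-level values copied from the log; in a real run you would append the losses to lists inside the training loop (matplotlib comes pre-installed on Colab).

     1 import matplotlib.pyplot as plt
     2 
     3 # Epoch-level losses copied from the log above, sampled every 5 epochs.
     4 epochs       = [25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95]
     5 train_losses = [1.344, 1.325, 1.298, 1.282, 1.259, 1.243, 1.224, 1.205,
     6                 1.189, 1.173, 1.159, 1.143, 1.128, 1.118, 1.099]
     7 eval_losses  = [3.332, 3.346, 3.365, 3.378, 3.393, 3.410, 3.419, 3.429,
     8                 3.454, 3.472, 3.481, 3.491, 3.508, 3.509, 3.533]
     9 
    10 plt.plot(epochs, train_losses, label="training loss")
    11 plt.plot(epochs, eval_losses, label="evaluation loss")
    12 plt.xlabel("epoch")
    13 plt.ylabel("cross-entropy loss")
    14 plt.legend()
    15 plt.show()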

    7.6 Calling the translate_dev function defined above

    1 for i in range(100,120):    # translate dev sentences 100-119
    2     translate_dev(i)         # prints the English source, the Chinese reference, and the model's translation
    3     print()

    The output is as follows:

     1 BOS you have nice skin . EOS
     2 BOS 你 的 皮 膚 真 好 。 EOS
     3 你有足球的食物都好了
     4 
     5 BOS you 're UNK correct . EOS
     6 BOS 你 部 分 正 确 。 EOS
     7 你是个好厨师。
     8 
     9 BOS everyone admired his courage . EOS
    10 BOS 每 個 人 都 佩 服 他 的 勇 氣 。 EOS
    11 他們每個人都很好奇。
    12 
    13 BOS what time is it ? EOS
    14 BOS 几 点 了 ? EOS
    15 它是什麼?
    16 
    17 BOS i 'm free tonight . EOS
    18 BOS 我 今 晚 有 空 。 EOS
    19 我今晚沒有空。
    20 
    21 BOS here is your book . EOS
    22 BOS 這 是 你 的 書 。 EOS
    23 你這附書是讀書。
    24 
    25 BOS they are at lunch . EOS
    26 BOS 他 们 在 吃 午 饭 。 EOS
    27 他们在吃米饭。
    28 
    29 BOS this chair is UNK . EOS
    30 BOS 這 把 椅 子 很 UNK 。 EOS
    31 這是真的最好的。
    32 
    33 BOS it 's pretty heavy . EOS
    34 BOS 它 真 重 。 EOS
    35 它很有。
    36 
    37 BOS many attended his funeral . EOS
    38 BOS 很 多 人 都 参 加 了 他 的 葬 礼 。 EOS
    39 仔细把他的罪恶着。
    40 
    41 BOS training will be provided . EOS
    42 BOS 会 有 训 练 。 EOS
    43 克林變得餓了。
    44 
    45 BOS someone is watching you . EOS
    46 BOS 有 人 在 看 著 你 。 EOS
    47 有人在信封信。
    48 
    49 BOS i slapped his face . EOS
    50 BOS 我 摑 了 他 的 臉 。 EOS
    51 我是他的兄弟。
    52 
    53 BOS i like UNK music . EOS
    54 BOS 我 喜 歡 流 行 音 樂 。 EOS
    55 我喜歡音樂。
    56 
    57 BOS tom had no children . EOS
    58 BOS T o m 沒 有 孩 子 。 EOS
    59 Tom沒有太累。
    60 
    61 BOS please lock the door . EOS
    62 BOS 請 把 門 鎖 上 。 EOS
    63 請關門。
    64 
    65 BOS tom has calmed down . EOS
    66 BOS 汤 姆 冷 静 下 来 了 。 EOS
    67 汤姆向伤极了。
    68 
    69 BOS please speak more loudly . EOS
    70 BOS 請 說 大 聲 一 點 兒 。 EOS
    71 請講話。
    72 
    73 BOS keep next sunday free . EOS
    74 BOS 把 下 周 日 空 出 来 。 EOS
    75 下周举一直流出席。
    76 
    77 BOS i made a mistake . EOS
    78 BOS 我 犯 了 一 個 錯 。 EOS
    79 我有一个梦意。

    The translation results are still mediocre: most outputs share only a few characters with the reference, in line with the overfitting visible in the training log above.
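    A scan like this is only qualitative. To put a single number on translation quality, a character-level BLEU score can be computed with nltk (already imported at the top of the notebook). The snippet below is a minimal sketch: the two hard-coded sentences stand in for one reference/output pair from the results above, and in practice you would decode every dev sentence and pass all the pairs to corpus_bleu at once.

     1 from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
     2 
     3 # Stand-in data: one reference / hypothesis pair copied from the output above.
     4 # A real evaluation would collect these lists over the whole dev set.
     5 references = [[list("我今晚有空。")]]   # each hypothesis may have several references
     6 hypotheses = [list("我今晚沒有空。")]   # model output, split into characters
     7 
     8 # Smoothing avoids a zero score when some higher-order n-gram never matches.
     9 smooth = SmoothingFunction().method1
    10 bleu = corpus_bleu(references, hypotheses, smoothing_function=smooth)
    11 print("character-level BLEU: %.4f" % bleu)

    With only a handful of short sentences the score is noisy, but averaged over the full dev set it gives a number that can be tracked alongside the evaluation loss.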

  • Original post: https://www.cnblogs.com/cxq1126/p/13565961.html