  • Transformer code notes ---- decoder.py

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    from config import IGNORE_ID
    from .attention import MultiHeadAttention
    from .module import PositionalEncoding, PositionwiseFeedForward
    from .utils import get_attn_key_pad_mask, get_attn_pad_mask, get_non_pad_mask, get_subsequent_mask, pad_list
    
    
    # filename = 'bigram_freq.pkl'
    # print('loading {}...'.format(filename))
    # with open(filename, 'rb') as file:
    #     bigram_freq = pickle.load(file)
    
    
    class Decoder(nn.Module):
        ''' A decoder model with a self-attention mechanism. '''
    
        def __init__(
                self, sos_id=0, eos_id=1,
                n_tgt_vocab=4335, d_word_vec=512,
                n_layers=6, n_head=8, d_k=64, d_v=64,
                d_model=512, d_inner=2048, dropout=0.1,
                tgt_emb_prj_weight_sharing=True,
                pe_maxlen=5000):
            super(Decoder, self).__init__()
            # store hyper-parameters
            self.sos_id = sos_id  # Start of Sentence
            self.eos_id = eos_id  # End of Sentence
            self.n_tgt_vocab = n_tgt_vocab
            self.d_word_vec = d_word_vec
            self.n_layers = n_layers
            self.n_head = n_head
            self.d_k = d_k
            self.d_v = d_v
            self.d_model = d_model
            self.d_inner = d_inner
            self.dropout = dropout
            self.tgt_emb_prj_weight_sharing = tgt_emb_prj_weight_sharing
            self.pe_maxlen = pe_maxlen
    
            self.tgt_word_emb = nn.Embedding(n_tgt_vocab, d_word_vec)
            self.positional_encoding = PositionalEncoding(d_model, max_len=pe_maxlen)
            self.dropout = nn.Dropout(dropout)
    
            self.layer_stack = nn.ModuleList([
                DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
                for _ in range(n_layers)])   # number of stacked decoder layers
    
            self.tgt_word_prj = nn.Linear(d_model, n_tgt_vocab, bias=False)  # linear projection to vocabulary logits
            nn.init.xavier_normal_(self.tgt_word_prj.weight)  # Xavier initialization
    
            if tgt_emb_prj_weight_sharing:  # defaults to True
                # Share the weight matrix between target word embedding & the final logit dense layer
                self.tgt_word_prj.weight = self.tgt_word_emb.weight  # tie the target embedding and output projection weights
                self.x_logit_scale = (d_model ** -0.5)  # scale applied to the embedding output when weights are tied
            else:
                self.x_logit_scale = 1.
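
            # Note on weight tying (not in the original file): when the weights are shared,
            # self.tgt_word_prj(x) computes x @ self.tgt_word_emb.weight.t(), and the
            # embedding output is additionally scaled by d_model ** -0.5 in forward().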
    
        def preprocess(self, padded_input):  # preprocessing
            """Generate decoder input and output label from padded_input
            Add <sos> to decoder input, and add <eos> to decoder output label
            """
            ys = [y[y != IGNORE_ID] for y in padded_input]  # strip padding (IGNORE_ID = -1) from each target sequence
            # prepare input and output word sequences with sos/eos IDs
            eos = ys[0].new([self.eos_id])  # one-element tensor holding the <eos> id
            # .new() creates a tensor with the same dtype and device as the original tensor
            sos = ys[0].new([self.sos_id])
            ys_in = [torch.cat([sos, y], dim=0) for y in ys]  # prepend the <sos> tag to each target
            ys_out = [torch.cat([y, eos], dim=0) for y in ys]  # append the <eos> tag to each target
            # padding for ys with -1
            # pys: utt x olen
            ys_in_pad = pad_list(ys_in, self.eos_id)  # pad decoder inputs with eos_id
            ys_out_pad = pad_list(ys_out, IGNORE_ID)  # pad output labels with IGNORE_ID
            assert ys_in_pad.size() == ys_out_pad.size()  # inputs and labels must have the same padded shape
            return ys_in_pad, ys_out_pad  # tagged and padded decoder inputs and labels
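
        # A worked example of preprocess (a sketch, not in the original file; assumes
        # IGNORE_ID == -1 and the default sos_id=0, eos_id=1):
        #   padded_input = torch.tensor([[5, 6, 7, -1],
        #                                [8, 9, -1, -1]])
        #   ys_in_pad  -> [[0, 5, 6, 7],    # <sos> prepended, padded with eos_id
        #                  [0, 8, 9, 1]]
        #   ys_out_pad -> [[5, 6, 7, 1],    # <eos> appended, padded with IGNORE_ID
        #                  [8, 9, 1, -1]]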
    
        def forward(self, padded_input, encoder_padded_outputs,
                    encoder_input_lengths, return_attns=False):
            """
            Args:
                padded_input: N x To
                encoder_padded_outputs: N x Ti x H
            Returns:
                pred: vocabulary logits for each target position
                gold: padded target labels (IGNORE_ID at padding)
            """
            dec_slf_attn_list, dec_enc_attn_list = [], []  # collect decoder self-attention and encoder-decoder attention weights
    
            # Get decoder input and output
            ys_in_pad, ys_out_pad = self.preprocess(padded_input)  # add <sos>/<eos> and pad
    
            # Prepare masks
            non_pad_mask = get_non_pad_mask(ys_in_pad, pad_idx=self.eos_id)  # mask padding positions in the decoder input
    
            slf_attn_mask_subseq = get_subsequent_mask(ys_in_pad)  # mask future (subsequent) positions in the target
            slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=ys_in_pad,
                                                         seq_q=ys_in_pad,
                                                         pad_idx=self.eos_id)  # mask padded key positions
            slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)  # combined self-attention mask
    
            output_length = ys_in_pad.size(1)
            dec_enc_attn_mask = get_attn_pad_mask(encoder_padded_outputs,
                                                  encoder_input_lengths,
                                                  output_length)  # encoder-decoder (cross) attention mask
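
            # Note (not in the original file): with the mask helpers from .utils these
            # tensors typically look like
            #   non_pad_mask      : N x To x 1  (float, 1 at real tokens, 0 at padding)
            #   slf_attn_mask     : N x To x To (nonzero where attention is blocked:
            #                                    padded keys or future positions)
            #   dec_enc_attn_mask : N x To x Ti (nonzero at padded encoder frames)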
    
            # Forward
            dec_output = self.dropout(self.tgt_word_emb(ys_in_pad) * self.x_logit_scale +
                                      self.positional_encoding(ys_in_pad))  # scaled word embedding plus positional encoding
    
            for dec_layer in self.layer_stack:  # run through the stacked decoder layers
                dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
                    dec_output, encoder_padded_outputs,
                    non_pad_mask=non_pad_mask,
                    slf_attn_mask=slf_attn_mask,
                    dec_enc_attn_mask=dec_enc_attn_mask)
    
                if return_attns:  # defaults to False
                    dec_slf_attn_list += [dec_slf_attn]
                    dec_enc_attn_list += [dec_enc_attn]
    
            # before softmax
            seq_logit = self.tgt_word_prj(dec_output)  # project the decoder output to vocabulary logits
    
            # Return
            pred, gold = seq_logit, ys_out_pad  # predictions (logits) and gold labels
    
            if return_attns:
                return pred, gold, dec_slf_attn_list, dec_enc_attn_list
            return pred, gold
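
        # Typical training usage (a sketch, not in the original file; the project itself
        # may apply label smoothing instead of plain cross-entropy):
        #   pred, gold = decoder(padded_input, encoder_padded_outputs, encoder_input_lengths)
        #   loss = F.cross_entropy(pred.view(-1, pred.size(-1)),
        #                          gold.contiguous().view(-1),
        #                          ignore_index=IGNORE_ID)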
    
        def recognize_beam(self, encoder_outputs, char_list, args):
            """Beam search, decode one utterence now.
            Args:
                encoder_outputs: T x H
                char_list: list of characters
                args: beam-search options (beam_size, nbest, decode_max_len)
            Returns:
                nbest_hyps:
            """
            # search params
            beam = args.beam_size
            nbest = args.nbest
            if args.decode_max_len == 0:
                maxlen = encoder_outputs.size(0)
            else:
                maxlen = args.decode_max_len
    
            encoder_outputs = encoder_outputs.unsqueeze(0)  # add a batch dimension: 1 x T x H
    
            # prepare sos
            # start every hypothesis with the <sos> token
            ys = torch.ones(1, 1).fill_(self.sos_id).type_as(encoder_outputs).long()
            # torch.ones(size) builds an all-ones tensor; fill_(v) overwrites its values with v;
            # type_as(b) matches b's dtype and device; long() casts to int64 token ids
    
            # yseq: 1xT
            hyp = {'score': 0.0, 'yseq': ys}
            hyps = [hyp]
            ended_hyps = []
    
            for i in range(maxlen):
                hyps_best_kept = []
                for hyp in hyps:
                    ys = hyp['yseq']  # 1 x i
                    # last_id = ys.cpu().numpy()[0][-1]
                    # freq = bigram_freq[last_id]
                    # freq = torch.log(torch.from_numpy(freq))
                    # # print(freq.dtype)
                    # freq = freq.type(torch.float).to(device)
                    # print(freq.dtype)
                    # print('freq.size(): ' + str(freq.size()))
                    # print('freq: ' + str(freq))
                    # -- Prepare masks
                    non_pad_mask = torch.ones_like(ys).float().unsqueeze(-1)  # 1xix1
                    slf_attn_mask = get_subsequent_mask(ys)
    
                    # -- Forward
                    dec_output = self.dropout(
                        self.tgt_word_emb(ys) * self.x_logit_scale +
                        self.positional_encoding(ys))
    
                    for dec_layer in self.layer_stack:
                        dec_output, _, _ = dec_layer(
                            dec_output, encoder_outputs,
                            non_pad_mask=non_pad_mask,
                            slf_attn_mask=slf_attn_mask,
                            dec_enc_attn_mask=None)
    
                    seq_logit = self.tgt_word_prj(dec_output[:, -1])
                    # local_scores = F.log_softmax(seq_logit, dim=1)
                    local_scores = F.log_softmax(seq_logit, dim=1)
                    # print('local_scores.size(): ' + str(local_scores.size()))
                    # local_scores += freq
                    # print('local_scores: ' + str(local_scores))
    
                    # topk scores
                    local_best_scores, local_best_ids = torch.topk(
                        local_scores, beam, dim=1)
    
                    for j in range(beam):
                        new_hyp = {}
                        new_hyp['score'] = hyp['score'] + local_best_scores[0, j]
                        new_hyp['yseq'] = torch.ones(1, (1 + ys.size(1))).type_as(encoder_outputs).long()
                        new_hyp['yseq'][:, :ys.size(1)] = hyp['yseq']
                        new_hyp['yseq'][:, ys.size(1)] = int(local_best_ids[0, j])
                        # will be (2 x beam) hyps at most
                        hyps_best_kept.append(new_hyp)
    
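                    # after expanding this hypothesis, keep only the top-`beam` partial
                    # hypotheses accumulated so far in this time step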
                    hyps_best_kept = sorted(hyps_best_kept,
                                            key=lambda x: x['score'],
                                            reverse=True)[:beam]
                # end for hyp in hyps
                hyps = hyps_best_kept
    
                # force <eos> in the final step so that at least one hypothesis ends
                if i == maxlen - 1:
                    for hyp in hyps:
                        hyp['yseq'] = torch.cat([hyp['yseq'],
                                                 torch.ones(1, 1).fill_(self.eos_id).type_as(encoder_outputs).long()],
                                                dim=1)
    
                # move ended hypotheses to the final list and remove them from the active set
                # (this can be a problem: the number of active hyps may fall below beam)
                remained_hyps = []
                for hyp in hyps:
                    if hyp['yseq'][0, -1] == self.eos_id:
                        ended_hyps.append(hyp)
                    else:
                        remained_hyps.append(hyp)
    
                hyps = remained_hyps
                # if len(hyps) > 0:
                #     print('remeined hypothes: ' + str(len(hyps)))
                # else:
                #     print('no hypothesis. Finish decoding.')
                #     break
                #
                # for hyp in hyps:
                #     print('hypo: ' + ''.join([char_list[int(x)]
                #                               for x in hyp['yseq'][0, 1:]]))
            # end for i in range(maxlen)
            nbest_hyps = sorted(ended_hyps, key=lambda x: x['score'], reverse=True)[
                         :min(len(ended_hyps), nbest)]
            # compatible with the LAS implementation
            for hyp in nbest_hyps:
                hyp['yseq'] = hyp['yseq'][0].cpu().numpy().tolist()
            return nbest_hyps
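
        # Hypothetical decoding usage (a sketch, not in the original file; `args` only
        # needs the fields read above, and encoder_outputs_single is T x H for one utterance):
        #   args = argparse.Namespace(beam_size=5, nbest=1, decode_max_len=0)
        #   with torch.no_grad():
        #       nbest_hyps = decoder.recognize_beam(encoder_outputs_single, char_list, args)
        #   best = ''.join(char_list[idx] for idx in nbest_hyps[0]['yseq'][1:-1])  # strip <sos>/<eos>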
    
    
    class DecoderLayer(nn.Module):
        ''' A single decoder block: self-attention, encoder-decoder attention, and a position-wise feed-forward layer. '''
    
        def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
            super(DecoderLayer, self).__init__()
            self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
            self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
            self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
    
        def forward(self, dec_input, enc_output, non_pad_mask=None, slf_attn_mask=None, dec_enc_attn_mask=None):
            dec_output, dec_slf_attn = self.slf_attn(
                dec_input, dec_input, dec_input, mask=slf_attn_mask)
            dec_output *= non_pad_mask
    
            dec_output, dec_enc_attn = self.enc_attn(
                dec_output, enc_output, enc_output, mask=dec_enc_attn_mask)
            dec_output *= non_pad_mask
    
            dec_output = self.pos_ffn(dec_output)
            dec_output *= non_pad_mask
    
            return dec_output, dec_slf_attn, dec_enc_attn
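

    # A minimal shape check (a sketch, assuming the sibling modules in this package are
    # importable and config.IGNORE_ID == -1; tensor values are hypothetical):
    #   decoder = Decoder(n_tgt_vocab=4335)
    #   padded_input = torch.randint(2, 4335, (4, 20))       # N x To target ids
    #   encoder_padded_outputs = torch.randn(4, 100, 512)     # N x Ti x H
    #   encoder_input_lengths = torch.tensor([100, 90, 80, 70])
    #   pred, gold = decoder(padded_input, encoder_padded_outputs, encoder_input_lengths)
    #   # pred: 4 x 21 x 4335 logits, gold: 4 x 21 labels (IGNORE_ID at padding)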
  • Original article: https://www.cnblogs.com/Uriel-w/p/15426155.html