zoukankan      html  css  js  c++  java
  • transformer代码笔记----transformer.py

    import torch.nn as nn
    
    from .decoder import Decoder
    from .encoder import Encoder
    
    
    class Transformer(nn.Module):  #定义类,继承父类nn.Module
        """An encoder-decoder framework only includes attention.
        """
    
        def __init__(self, encoder=None, decoder=None):  #参数encoder和decoder设置默认值None
            super(Transformer, self).__init__()          #继承父类__init__()
    
            if encoder is not None and decoder is not None:   #判断decoder和encoder是否被重新赋值
                self.encoder = encoder
                self.decoder = decoder
    
                for p in self.parameters():  #获取网络参数
                    if p.dim() > 1:
                        nn.init.xavier_uniform_(p)  #参数初始化,torch.nn.init.xavier_uniform_是一个服从均匀分布的Glorot初始化器
            # else:
            #     self.encoder = Encoder()    #对全局变量赋值
            #     self.decoder = Decoder()
    
        def forward(self, padded_input, input_lengths, padded_target):  #编码器中的前向传播
            """
            Args:
                padded_input: B x Ti x D   表示编码器输入时数据结构
                其中B(一维向量):批量中每个音频的具体长度;Ti:该批量中音频的最大长度;
                input_lengths: B   每个音频的具体长度,假设批量大小为32,则B可表示为[2,3,45,6....],维度32
                padded_targets: B x To   表示解码器的输入数据结构,这里的B和上面的B不同,因为编码器中是音频的输入,解码器中的输入是字符
            """
            encoder_padded_outputs, *_ = self.encoder(padded_input, input_lengths)
            # pred is score before softmax
            pred, gold, *_ = self.decoder(padded_target, encoder_padded_outputs,
                                          input_lengths)
            return pred, gold
    
        def recognize(self, input, input_length, char_list, args):   #解码器中的识别过程
            """Sequence-to-Sequence beam search, decode one utterence now.
            Args:
                input: T x D
                char_list: list of characters
                args: args.beam
            Returns:
                nbest_hyps:
            """
            encoder_outputs, *_ = self.encoder(input.unsqueeze(0), input_length)
            nbest_hyps = self.decoder.recognize_beam(encoder_outputs[0],
                                                     char_list,
                                                     args)
            return nbest_hyps
  • 相关阅读:
    poj 3320 Jessica's Reading Problem
    uva 120 C
    vim使用教程-转自
    2015 俄罗斯网络赛 D. Boulevard
    HTML转义字符大全
    介绍个好点的,JAVA技术群
    JAVA学习路线
    linux常用命令大全(转)好东西要分享
    Jqprint 轻量级页面打印插件
    hadoop集群搭建
  • 原文地址:https://www.cnblogs.com/Uriel-w/p/15426153.html
Copyright © 2011-2022 走看看