zoukankan      html  css  js  c++  java
  • transformer代码笔记----encoder.py

    import torch.nn as nn
    from .attention import MultiHeadAttention   #引进多头注意力模块
    from .module import PositionalEncoding, PositionwiseFeedForward  #位置编码和前馈网络
    from .utils import get_non_pad_mask, get_attn_pad_mask  #padding mask:填充补齐使得输入长度相同。attention mask:
    
    
    class Encoder(nn.Module):
        """Encoder of Transformer including self-attention and feed forward.
        """
    
        def __init__(self, d_input=320, n_layers=6, n_head=8, d_k=64, d_v=64,
                     d_model=512, d_inner=2048, dropout=0.1, pe_maxlen=5000):
            super(Encoder, self).__init__()
            # parameters
            self.d_input = d_input   #输入维度
            self.n_layers = n_layers #编码解码层数
            self.n_head = n_head     #自注意力头数
            self.d_k = d_k           #键矩阵维度
            self.d_v = d_v           #值矩阵维度
            self.d_model = d_model   #模型维度
            self.d_inner = d_inner   #前馈网络隐层神经元个数(维度)
            self.dropout_rate = dropout #信息漏失率
            self.pe_maxlen = pe_maxlen  #位置编码最大长度
    
            # use linear transformation with layer norm to replace input embedding
            self.linear_in = nn.Linear(d_input, d_model) #全连接,输入为batch size 和size
            self.layer_norm_in = nn.LayerNorm(d_model)   #层归一化
            self.positional_encoding = PositionalEncoding(d_model, max_len=pe_maxlen) #位置编码
            self.dropout = nn.Dropout(dropout) #dropout
    
            self.layer_stack = nn.ModuleList([
                EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
                for _ in range(n_layers)])   #实现n_layers次编码器
            #nn.ModuleList,它是一个储存不同 module,并自动将每个 module 的 parameters 添加到网络之中的容器。
    
        def forward(self, padded_input, input_lengths, return_attns=False):
            """
            Args:
                padded_input: N x T x D
                input_lengths: N
            Returns:
                enc_output: N x T x H
            """
            enc_slf_attn_list = []
    
            # Prepare masks
            non_pad_mask = get_non_pad_mask(padded_input, input_lengths=input_lengths) #没有填充前
            length = padded_input.size(1)  #获得填充维度
            slf_attn_mask = get_attn_pad_mask(padded_input, input_lengths, length) #获得填充结果
    
            # Forward
            # 进入编码器前对数据的处理
            enc_output = self.dropout(
                self.layer_norm_in(self.linear_in(padded_input)) +
                self.positional_encoding(padded_input)) 
                #对数据线性变换(将320维的输入变为512维)后归一化,然后加上位置编码后的数据进行dropout
    
            for enc_layer in self.layer_stack:  #进入编码器
                enc_output, enc_slf_attn = enc_layer( 
                    enc_output,
                    non_pad_mask=non_pad_mask,
                    slf_attn_mask=slf_attn_mask)#经过编码器输出编码结果和注意力
                if return_attns: #默认不对每层的注意力形成列表形式
                    enc_slf_attn_list += [enc_slf_attn]
    
            if return_attns: #默认为false
                return enc_output, enc_slf_attn_list
            return enc_output, #返回最后层编码器输出
    
    
    class EncoderLayer(nn.Module):
        """Compose with two sub-layers.
            1. A multi-head self-attention mechanism
            2. A simple, position-wise fully connected feed-forward network.
        """
    
        def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
            super(EncoderLayer, self).__init__()
            self.slf_attn = MultiHeadAttention(
                n_head, d_model, d_k, d_v, dropout=dropout)  #多头注意力实例化
            self.pos_ffn = PositionwiseFeedForward(
                d_model, d_inner, dropout=dropout)           #前馈网络实例化
    
        def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None):
            enc_output, enc_slf_attn = self.slf_attn(
                enc_input, enc_input, enc_input, mask=slf_attn_mask) #获得多头注意力的输出
            enc_output *= non_pad_mask     #防止经过注意力层后数据的长度发生变化
    
            enc_output = self.pos_ffn(enc_output)   #前馈网络的输出
            enc_output *= non_pad_mask
    
            return enc_output, enc_slf_attn    #返回一个编码器的输出
  • 相关阅读:
    学习笔记2
    带有循环的存储过程
    经典SQL语句大全
    关于职业的一些看法
    把dataTable表批量的写入数据库
    抽奖接口,每天只能抽奖3次,而且必须先登录才能抽奖的小程序
    调用获取学生信息的接口,保存到excel里面的小程序
    内置函数补充
    好用的模块
    网络编程
  • 原文地址:https://www.cnblogs.com/Uriel-w/p/15426146.html
Copyright © 2011-2022 走看看