zoukankan      html  css  js  c++  java
  • transformer代码笔记----encoder.py

    import torch.nn as nn
    from .attention import MultiHeadAttention   #引进多头注意力模块
    from .module import PositionalEncoding, PositionwiseFeedForward  #位置编码和前馈网络
    from .utils import get_non_pad_mask, get_attn_pad_mask  #padding mask:填充补齐使得输入长度相同。attention mask:
    
    
    class Encoder(nn.Module):
        """Encoder of Transformer including self-attention and feed forward.
        """
    
        def __init__(self, d_input=320, n_layers=6, n_head=8, d_k=64, d_v=64,
                     d_model=512, d_inner=2048, dropout=0.1, pe_maxlen=5000):
            super(Encoder, self).__init__()
            # parameters
            self.d_input = d_input   #输入维度
            self.n_layers = n_layers #编码解码层数
            self.n_head = n_head     #自注意力头数
            self.d_k = d_k           #键矩阵维度
            self.d_v = d_v           #值矩阵维度
            self.d_model = d_model   #模型维度
            self.d_inner = d_inner   #前馈网络隐层神经元个数(维度)
            self.dropout_rate = dropout #信息漏失率
            self.pe_maxlen = pe_maxlen  #位置编码最大长度
    
            # use linear transformation with layer norm to replace input embedding
            self.linear_in = nn.Linear(d_input, d_model) #全连接,输入为batch size 和size
            self.layer_norm_in = nn.LayerNorm(d_model)   #层归一化
            self.positional_encoding = PositionalEncoding(d_model, max_len=pe_maxlen) #位置编码
            self.dropout = nn.Dropout(dropout) #dropout
    
            self.layer_stack = nn.ModuleList([
                EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
                for _ in range(n_layers)])   #实现n_layers次编码器
            #nn.ModuleList,它是一个储存不同 module,并自动将每个 module 的 parameters 添加到网络之中的容器。
    
        def forward(self, padded_input, input_lengths, return_attns=False):
            """
            Args:
                padded_input: N x T x D
                input_lengths: N
            Returns:
                enc_output: N x T x H
            """
            enc_slf_attn_list = []
    
            # Prepare masks
            non_pad_mask = get_non_pad_mask(padded_input, input_lengths=input_lengths) #没有填充前
            length = padded_input.size(1)  #获得填充维度
            slf_attn_mask = get_attn_pad_mask(padded_input, input_lengths, length) #获得填充结果
    
            # Forward
            # 进入编码器前对数据的处理
            enc_output = self.dropout(
                self.layer_norm_in(self.linear_in(padded_input)) +
                self.positional_encoding(padded_input)) 
                #对数据线性变换(将320维的输入变为512维)后归一化,然后加上位置编码后的数据进行dropout
    
            for enc_layer in self.layer_stack:  #进入编码器
                enc_output, enc_slf_attn = enc_layer( 
                    enc_output,
                    non_pad_mask=non_pad_mask,
                    slf_attn_mask=slf_attn_mask)#经过编码器输出编码结果和注意力
                if return_attns: #默认不对每层的注意力形成列表形式
                    enc_slf_attn_list += [enc_slf_attn]
    
            if return_attns: #默认为false
                return enc_output, enc_slf_attn_list
            return enc_output, #返回最后层编码器输出
    
    
    class EncoderLayer(nn.Module):
        """Compose with two sub-layers.
            1. A multi-head self-attention mechanism
            2. A simple, position-wise fully connected feed-forward network.
        """
    
        def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
            super(EncoderLayer, self).__init__()
            self.slf_attn = MultiHeadAttention(
                n_head, d_model, d_k, d_v, dropout=dropout)  #多头注意力实例化
            self.pos_ffn = PositionwiseFeedForward(
                d_model, d_inner, dropout=dropout)           #前馈网络实例化
    
        def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None):
            enc_output, enc_slf_attn = self.slf_attn(
                enc_input, enc_input, enc_input, mask=slf_attn_mask) #获得多头注意力的输出
            enc_output *= non_pad_mask     #防止经过注意力层后数据的长度发生变化
    
            enc_output = self.pos_ffn(enc_output)   #前馈网络的输出
            enc_output *= non_pad_mask
    
            return enc_output, enc_slf_attn    #返回一个编码器的输出
  • 相关阅读:
    Delphi XE5 android 蓝牙通讯传输
    Delphi XE5 android toast
    Delphi XE5 android openurl(转)
    Delphi XE5 如何设计并使用FireMonkeyStyle(转)
    Delphi XE5 android 捕获几个事件
    Delphi XE5 android listview
    Delphi XE5 android 黑屏的临时解决办法
    Delphi XE5 android popumenu
    Delphi XE5 android 获取网络状态
    Delphi XE5 android 获取电池电量
  • 原文地址:https://www.cnblogs.com/Uriel-w/p/15426146.html
Copyright © 2011-2022 走看看