  • Dive into PyTorch: Transformer Implementation

    Transformer Code Implementation

    1. Masked softmax
    2. Multi-head attention
    3. Position-wise FFN
    4. Add and Norm
    5. Positional encoding
    6. Encoder block
    7. Transformer encoder
    8. Decoder block
    9. Transformer decoder
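
    The snippets below come from a notebook and do not show their imports; a minimal header along these lines is assumed:

    import math
    import numpy as np
    import torch
    import torch.nn as nn
    import torch.nn.functional as F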

    1. Masked softmax

    def SequenceMask(X, X_len, value=-1e6):
        # set every position at or beyond its sequence's valid length to `value`
        maxlen = X.size(1)
        mask = torch.arange(maxlen, dtype=torch.float, device=X.device)[None, :] >= X_len[:, None]
        X[mask] = value
        return X

    def masked_softmax(X, valid_length):
        # X: 3-D tensor (batch, #queries, #keys); valid_length: 1-D or 2-D tensor, or None
        softmax = nn.Softmax(dim=-1)
        if valid_length is None:
            return softmax(X)
        else:
            shape = X.shape
            if valid_length.dim() == 1:
                # one valid length per example: repeat it for every query row, e.g. [2, 3] -> [2, 2, 3, 3]
                valid_length = torch.FloatTensor(valid_length.cpu().numpy().repeat(shape[1], axis=0))
            else:
                valid_length = valid_length.reshape((-1,))
            # fill masked positions with a large negative value, whose exp is ~0
            X = SequenceMask(X.reshape((-1, shape[-1])), valid_length.to(X.device))
            return softmax(X).reshape(shape)
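
    A quick sanity check (shapes and values chosen only for illustration): with valid lengths [2, 3], the first example's softmax mass is spread over its first two positions, the second example's over its first three.

    X = torch.rand(2, 2, 4)
    valid_length = torch.FloatTensor([2, 3])
    print(masked_softmax(X, valid_length))
    # positions beyond each valid length receive (near-)zero probability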
    

    2. Multi-head attention

    class MultiHeadAttention(nn.Module):
        def __init__(self, input_size, hidden_size, num_heads, dropout, **kwargs):
            super(MultiHeadAttention, self).__init__(**kwargs)
            self.num_heads = num_heads
            self.attention = DotProductAttention(dropout)
            self.wq = nn.Linear(input_size, hidden_size, bias=False)
            self.wk = nn.Linear(input_size, hidden_size, bias=False)
            self.wv = nn.Linear(input_size, hidden_size, bias=False)
            self.wo = nn.Linear(hidden_size, hidden_size, bias=False)
            
        def forward(self, query, key, value, valid_length):
            # project and split heads: (batch, seq_len, hidden) -> (batch * num_heads, seq_len, hidden / num_heads)
            query = transpose_qkv(self.wq(query), self.num_heads)
            key = transpose_qkv(self.wk(key), self.num_heads)
            value = transpose_qkv(self.wv(value), self.num_heads)
            # replicate valid_length once per attention head
            valid_length = handle_valid_length(valid_length, self.num_heads)
            output = self.attention(query, key, value, valid_length)
            # merge the heads back: (batch, seq_len, hidden)
            output_concat = transpose_output(output, self.num_heads)
            return self.wo(output_concat)
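
    The block above relies on DotProductAttention, transpose_qkv, transpose_output and handle_valid_length, which the original post does not show. A minimal sketch consistent with the calls above might look like this (hypothetical implementations, not the author's originals; hidden_size is assumed divisible by num_heads):

    class DotProductAttention(nn.Module):
        def __init__(self, dropout):
            super(DotProductAttention, self).__init__()
            self.dropout = nn.Dropout(dropout)

        def forward(self, query, key, value, valid_length=None):
            d = query.shape[-1]
            # scaled dot-product scores: (batch, #queries, #keys)
            scores = torch.bmm(query, key.transpose(1, 2)) / math.sqrt(d)
            attention_weights = self.dropout(masked_softmax(scores, valid_length))
            return torch.bmm(attention_weights, value)

    def transpose_qkv(X, num_heads):
        # (batch, seq_len, hidden) -> (batch * num_heads, seq_len, hidden / num_heads)
        batch_size, seq_len, hidden_size = X.shape
        X = X.reshape(batch_size, seq_len, num_heads, hidden_size // num_heads)
        X = X.permute(0, 2, 1, 3).contiguous()
        return X.reshape(batch_size * num_heads, seq_len, -1)

    def transpose_output(X, num_heads):
        # inverse of transpose_qkv: (batch * num_heads, seq_len, d) -> (batch, seq_len, num_heads * d)
        batch_size = X.shape[0] // num_heads
        seq_len = X.shape[1]
        X = X.reshape(batch_size, num_heads, seq_len, -1)
        X = X.permute(0, 2, 1, 3).contiguous()
        return X.reshape(batch_size, seq_len, -1)

    def handle_valid_length(valid_length, num_heads):
        # replicate valid lengths once per head so they match the
        # (batch_size * num_heads, ...) layout produced by transpose_qkv
        if valid_length is None:
            return None
        return torch.repeat_interleave(valid_length, repeats=num_heads, dim=0)

    # quick shape check (sizes are arbitrary):
    attn = MultiHeadAttention(input_size=24, hidden_size=24, num_heads=8, dropout=0.5)
    X = torch.ones((2, 4, 24))
    print(attn(X, X, X, torch.FloatTensor([2, 3])).shape)  # torch.Size([2, 4, 24])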
    

    3. Position-wise FFN

    class PositionWiseFFN(nn.Module):
        def __init__(self, input_size, ffn_hidden_size, hidden_size_out, **kwargs):
            super(PositionWiseFFN, self).__init__(**kwargs)
            self.ffn_1 = nn.Linear(input_size, ffn_hidden_size)
            self.ffn_2 = nn.Linear(ffn_hidden_size, hidden_size_out)
        
        def forward(self, X):
            return self.ffn_2(F.relu(self.ffn_1(X)))
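
    Since nn.Linear acts on the last dimension, the same two-layer MLP is applied to every position independently. A quick check with toy sizes (chosen only for illustration):

    ffn = PositionWiseFFN(input_size=4, ffn_hidden_size=8, hidden_size_out=4)
    print(ffn(torch.ones((2, 3, 4))).shape)  # torch.Size([2, 3, 4])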
    

    4. Add and Norm

    class AddNorm(nn.Module):
        def __init__(self, hidden_size, dropout, **kwargs):
            super(AddNorm, self).__init__(**kwargs)
            self.dropout = nn.Dropout(dropout)
            self.norm = nn.LayerNorm(hidden_size)
            
        def forward(self, X, Y):
            return self.norm(self.dropout(Y) + X)
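
    AddNorm applies dropout to the sublayer output Y, adds the residual X, and layer-normalizes the sum, so X and Y must share the same shape. A toy check:

    add_norm = AddNorm(hidden_size=4, dropout=0.5)
    print(add_norm(torch.ones((2, 3, 4)), torch.ones((2, 3, 4))).shape)  # torch.Size([2, 3, 4])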
    

    5. Positional encoding

    class PositionalEncoding(nn.Module):
        def __init__(self, embed_size, dropout, max_len=1000):
            super(PositionalEncoding, self).__init__()
            self.dropout = nn.Dropout(dropout)
            # precompute the sinusoidal table P once; shape (1, max_len, embed_size), embed_size assumed even
            self.P = np.zeros((1, max_len, embed_size))
            X = np.arange(0, max_len).reshape(-1, 1) / np.power(10000, np.arange(0, embed_size, 2) / embed_size)
            self.P[:, :, 0::2] = np.sin(X)
            self.P[:, :, 1::2] = np.cos(X)
            self.P = torch.FloatTensor(self.P)

        def forward(self, X):
            # P is a plain attribute rather than a registered buffer, so move it to the GPU lazily
            if X.is_cuda and not self.P.is_cuda:
                self.P = self.P.cuda()
            X = X + self.P[:, :X.shape[1], :]
            return self.dropout(X)
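
    The table P implements the sinusoidal encoding PE(pos, 2i) = sin(pos / 10000^(2i/d)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d)); forward simply adds the first seq_len rows to the input. A toy check (sizes chosen only for illustration):

    pe = PositionalEncoding(embed_size=20, dropout=0)
    Y = pe(torch.zeros((1, 100, 20)))
    print(Y.shape)  # torch.Size([1, 100, 20]); each column Y[0, :, i] traces a sin/cos curve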
    

    6. Encoder block

    class EncoderBlock(nn.Module):
        def __init__(self, embed_size, ffn_hidden_size, num_heads, dropout, **kwargs):
            super(EncoderBlock, self).__init__(**kwargs)
            self.attention = MultiHeadAttention(embed_size, embed_size, num_heads, dropout)
            self.add_norm1 = AddNorm(embed_size, dropout)
            self.ffn = PositionWiseFFN(embed_size, ffn_hidden_size, embed_size)
            self.add_norm2 = AddNorm(embed_size, dropout)
        
        def forward(self, X, valid_length):
            Y = self.add_norm1(X, self.attention(X, X, X, valid_length))
            return self.add_norm2(Y, self.ffn(Y))
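
    The encoder block keeps the (batch, seq_len, embed_size) shape, which is what allows several blocks to be stacked. A toy check (sizes chosen only for illustration):

    encoder_blk = EncoderBlock(embed_size=24, ffn_hidden_size=48, num_heads=8, dropout=0.5)
    X = torch.ones((2, 100, 24))
    print(encoder_blk(X, torch.FloatTensor([3, 2])).shape)  # torch.Size([2, 100, 24])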
    

    7. Transformer encoder

    class TransformerEncoder(Encoder):
        def __init__(self, vocab_size, embed_size, ffn_hidden_size, num_heads, num_layers, dropout, **kwargs):
            super(TransformerEncoder, self).__init__(**kwargs)
            self.embed_size = embed_size
            self.embedding = nn.Embedding(vocab_size, embed_size)
            self.pos_encoding = PositionalEncoding(embed_size, dropout)
            self.blocks = nn.ModuleList()
            for i in range(num_layers):
                self.blocks.append(EncoderBlock(embed_size, ffn_hidden_size, num_heads, dropout))
                
        def forward(self, X, valid_length, *args):
            # scale embeddings by sqrt(embed_size) before adding the positional encoding
            X = self.pos_encoding(self.embedding(X) * math.sqrt(self.embed_size))
            for block in self.blocks:
                X = block(X, valid_length)
            return X
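
    Assuming Encoder is the thin base class used in the course (essentially an nn.Module whose forward is overridden here), a toy forward pass looks like this:

    encoder = TransformerEncoder(vocab_size=200, embed_size=24, ffn_hidden_size=48,
                                 num_heads=8, num_layers=2, dropout=0.5)
    tokens = torch.ones((2, 100), dtype=torch.long)
    print(encoder(tokens, torch.FloatTensor([3, 2])).shape)  # torch.Size([2, 100, 24])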
    

    8. Decoder block

    class DecoderBlock(nn.Module):
        def __init__(self, embed_size, ffn_hidden_size, num_heads, dropout, i, **kwargs):
            super(DecoderBlock, self).__init__(**kwargs)
            self.i = i
            self.atten1 = MultiHeadAttention(embed_size, embed_size, num_heads, dropout)
            self.add_norm1 = AddNorm(embed_size, dropout)
            self.atten2 = MultiHeadAttention(embed_size, embed_size, num_heads, dropout)
            self.add_norm2 = AddNorm(embed_size, dropout)
            self.ffn = PositionWiseFFN(embed_size, ffn_hidden_size, embed_size)
            self.add_norm3 = AddNorm(embed_size, dropout)
    
        def forward(self, X, state):
            enc_outputs, enc_valid_length = state[0], state[1]
            # during prediction, keys/values are the concatenation of everything decoded so far
            if state[2][self.i] is None:
                key_value = X
            else:
                key_value = torch.cat((state[2][self.i], X), dim=1)
            state[2][self.i] = key_value

            if self.training:
                # causal mask: query position t may only attend to positions 1..t
                batch_size, seq_len, _ = X.shape
                valid_length = torch.FloatTensor(np.tile(np.arange(1, seq_len + 1), (batch_size, 1)))
                valid_length = valid_length.to(X.device)
            else:
                # at prediction time the cached key_value already contains only past steps
                valid_length = None
            X2 = self.atten1(X, key_value, key_value, valid_length)  # masked self-attention
            Y = self.add_norm1(X, X2)
            Y2 = self.atten2(Y, enc_outputs, enc_outputs, enc_valid_length)  # encoder-decoder attention
            Z = self.add_norm2(Y, Y2)
            return self.add_norm3(Z, self.ffn(Z)), state
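
    A toy check in training mode, with the encoder output faked as ones for brevity (sizes chosen only for illustration):

    decoder_blk = DecoderBlock(embed_size=24, ffn_hidden_size=48, num_heads=8, dropout=0.5, i=0)
    X = torch.ones((2, 100, 24))
    # state = [encoder outputs, encoder valid lengths, per-block key/value cache]
    state = [torch.ones((2, 100, 24)), torch.FloatTensor([3, 2]), [None]]
    out, state = decoder_blk(X, state)
    print(out.shape)  # torch.Size([2, 100, 24])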
    

    9. Transformer decoder

    class TransformerDecoder(Decoder):
        def __init__(self, vocab_size, embed_size, ffn_hidden_size, num_heads, num_layers, dropout, **kwargs):
            super(TransformerDecoder, self).__init__(**kwargs)
            self.embed_size = embed_size
            self.num_layers = num_layers
            self.embedding = nn.Embedding(vocab_size, embed_size)
            self.pos_encoding = PositionalEncoding(embed_size, dropout)
            self.blocks = nn.ModuleList()
            for i in range(num_layers):
                self.blocks.append(DecoderBlock(embed_size, ffn_hidden_size, num_heads, dropout, i))
            self.dense = nn.Linear(embed_size, vocab_size)
    
        def init_state(self, enc_outputs, enc_valid_length, *args):
            return [enc_outputs, enc_valid_length, [None] * self.num_layers]
        
        def forward(self, X, state):
            X = self.pos_encoding(self.embedding(X) * math.sqrt(self.embed_size))
            for block in self.blocks:
                X, state = block(X, state)
            return self.dense(X), state
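
    Assuming Decoder is the matching base class from the course (an nn.Module exposing init_state and forward), a toy round trip with a faked encoder output:

    decoder = TransformerDecoder(vocab_size=200, embed_size=24, ffn_hidden_size=48,
                                 num_heads=8, num_layers=2, dropout=0.5)
    state = decoder.init_state(torch.ones((2, 100, 24)), torch.FloatTensor([3, 2]))
    out, state = decoder(torch.ones((2, 100), dtype=torch.long), state)
    print(out.shape)  # torch.Size([2, 100, 200])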
    
  • Original post: https://www.cnblogs.com/54hys/p/12325182.html