zoukankan      html  css  js  c++  java
  • multiheadattention-torch

    multiheadattention

    import math
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    
    class ScaledDotProductAttention(nn.Module):
        """Scaled dot-product attention: ``softmax(Q @ K^T / sqrt(d_k)) @ V``."""

        def forward(self, query, key, value, mask=None):
            """Attend over ``value`` using query/key similarity.

            The size of the last dimension of ``query`` is used as ``d_k``
            for the scaling factor.  When ``mask`` is given, every position
            where ``mask == 0`` has its logit replaced with ``-1e9`` before
            the softmax; ``mask`` must be broadcastable to the shape of the
            logits.

            Returns the attention-weighted combination of ``value``.
            """
            scale = math.sqrt(query.size(-1))
            logits = (query @ key.transpose(-2, -1)) / scale
            if mask is not None:
                # A large negative logit drives the softmax weight to ~0.
                logits = logits.masked_fill(mask == 0, -1e9)
            weights = F.softmax(logits, dim=-1)
            return weights @ value
    
    class MultiSelfAttention(nn.Module):
        """Multi-head attention built on top of ``ScaledDotProductAttention``.

        ``q``, ``k`` and ``v`` are each projected by their own linear layer,
        split into ``heads`` slices of width ``d_model // heads``, attended
        per head, re-joined and passed through a final linear layer.

        Args:
            heads: number of attention heads; must divide ``d_model``.
            d_model: feature width of the inputs and of the output.
            dropout: probability used to build ``self.dropout``.

        Raises:
            ValueError: if ``heads`` is not positive or does not divide
                ``d_model`` evenly.
        """

        def __init__(self, heads, d_model, dropout=0.1):
            super().__init__()

            # Without this check the per-head ``view`` in ``forward`` can
            # either fail or silently reshape with a wrong sequence length.
            if heads <= 0 or d_model % heads != 0:
                raise ValueError(
                    f"d_model ({d_model}) must be divisible by heads ({heads})"
                )

            self.d_model = d_model
            self.d_k = d_model // heads
            self.h = heads

            self.q_linear = nn.Linear(d_model, d_model)
            self.v_linear = nn.Linear(d_model, d_model)
            self.k_linear = nn.Linear(d_model, d_model)

            # NOTE(review): this module is constructed but never applied in
            # ``forward``; confirm where dropout is intended (attention
            # weights vs. sub-layer output) before wiring it in.
            self.dropout = nn.Dropout(dropout)
            self.out = nn.Linear(d_model, d_model)
            self.attention = ScaledDotProductAttention()

        def _split_heads(self, x, bs):
            """Reshape ``(bs, sl, d_model)`` to ``(bs, h, sl, d_k)``."""
            return x.view(bs, -1, self.h, self.d_k).transpose(1, 2)

        def forward(self, q, k, v, mask=None):
            """Run multi-head attention.

            Args:
                q, k, v: tensors whose first dimension is the batch and whose
                    last dimension is ``d_model``.
                mask: optional tensor; positions where ``mask == 0`` are
                    excluded from attention.  A 3-D mask gets a head
                    dimension inserted at index 1 so it broadcasts over all
                    heads; any other rank is passed through unchanged and
                    must already broadcast against ``(bs, h, sl_q, sl_k)``.

            Returns:
                Tensor of shape ``(bs, sl_q, d_model)``.
            """
            bs = q.size(0)  # batch size

            # Project, then split into heads: (bs, h, sl, d_k).
            k = self._split_heads(self.k_linear(k), bs)
            q = self._split_heads(self.q_linear(q), bs)
            v = self._split_heads(self.v_linear(v), bs)

            if mask is not None and mask.dim() == 3:
                # (bs, sl_q or 1, sl_k) -> (bs, 1, sl_q or 1, sl_k)
                mask = mask.unsqueeze(1)

            # The mask must be forwarded, otherwise it is silently ignored.
            context = self.attention(q, k, v, mask)

            # Merge heads back to (bs, sl_q, d_model) and project.
            concat = context.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
            return self.out(concat)
    
    
  • 相关阅读:
    Jmeter的两种录制脚本的方式
    【.NET】设置EntityFramework中decimal类型数据精度 [转]
    vscode格式化vue不换行
    mysql5.7 noinstall 安装 【转载】
    配置STP、RSTP以及负载均衡
    配置3层交换机VLAN间通信
    配置单臂路由
    配置DTP
    配置trunk
    配置VLAN
  • 原文地址:https://www.cnblogs.com/lixyuan/p/12919894.html
Copyright © 2011-2022 走看看