  • multiheadattention-torch

    multiheadattention

    import math
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    
    class ScaledDotProductAttention(nn.Module):
    
        def forward(self, query, key, value, mask=None):
            # scale the dot products by sqrt(d_k) so the softmax
            # stays in a region with useful gradients
            d_k = query.size(-1)
            scores = query.matmul(key.transpose(-2, -1)) / math.sqrt(d_k)
            if mask is not None:
                # masked positions get a large negative score,
                # so their softmax weight is effectively zero
                scores = scores.masked_fill(mask == 0, -1e9)
            attention = F.softmax(scores, dim=-1)
            return attention.matmul(value)
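    # ScaledDotProductAttention computes softmax(Q @ K^T / sqrt(d_k)) @ V.
    # A quick shape check (hypothetical sizes, not from the original post):
    #   attn = ScaledDotProductAttention()
    #   out = attn(torch.rand(2, 5, 64), torch.rand(2, 5, 64), torch.rand(2, 5, 64))
    #   out.shape  # torch.Size([2, 5, 64])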
    
    class MultiSelfAttention(nn.Module):
    
        def __init__(self, heads, d_model, dropout=0.1):
            super().__init__()
            
            self.d_model = d_model
            # per-head dimension; d_model must be divisible by heads
            self.d_k = d_model // heads
            self.h = heads
            
            self.q_linear = nn.Linear(d_model, d_model)
            self.v_linear = nn.Linear(d_model, d_model)
            self.k_linear = nn.Linear(d_model, d_model)
            
            self.dropout = nn.Dropout(dropout)
            self.out = nn.Linear(d_model, d_model)
            self.attention = ScaledDotProductAttention()
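            # Note: each projection keeps size d_model; reshaping its output
            # into h chunks of size d_k is equivalent to h independent
            # d_k-dimensional projections, computed as a single matmul.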
        
        def forward(self, q, k, v, mask=None):
            
            bs = q.size(0)  # batch size
            
            # perform the linear projections and split into h heads
            k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
            q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
            v = self.v_linear(v).view(bs, -1, self.h, self.d_k)
            
            # transpose to get dimensions bs * h * sl * d_k
            k = k.transpose(1, 2)
            q = q.transpose(1, 2)
            v = v.transpose(1, 2)
            
            if mask is not None:
                # add a head dimension so the mask broadcasts over all heads
                mask = mask.unsqueeze(1)
            
            # calculate attention using the module defined above
            scores = self.attention(q, k, v, mask=mask)
            scores = self.dropout(scores)
            
            # concatenate heads and put through the final linear layer
            concat = scores.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
            output = self.out(concat)
            
            return output
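
    A minimal usage sketch (not from the original post; the sizes d_model=512, heads=8, batch 2, sequence length 10 are hypothetical) to confirm the module runs and preserves shapes:

    model = MultiSelfAttention(heads=8, d_model=512)
    x = torch.rand(2, 10, 512)           # (batch, seq_len, d_model)
    pad_mask = torch.ones(2, 1, 10)      # (batch, 1, seq_len); zeros would mark padding
    out = model(x, x, x, mask=pad_mask)  # self-attention: q = k = v
    print(out.shape)                     # torch.Size([2, 10, 512])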
    
    
  • Original post: https://www.cnblogs.com/lixyuan/p/12919894.html