zoukankan      html  css  js  c++  java
  • multiheadattention-torch

    multiheadattention

    import math
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    
    class ScaledDotProductAttention(nn.Module):
        """Attention that weights ``value`` by softmax-normalised query/key similarity.

        The similarity logits are the dot products of ``query`` with ``key``
        (over the last dimension), divided by the square root of the size of
        the last dimension of ``query``.
        """

        def forward(self, query, key, value, mask=None):
            """Return the attention-weighted combination of ``value``.

            Args:
                query: Tensor whose last dimension is the feature dimension.
                key: Tensor whose last two dimensions are swapped before being
                    multiplied with ``query``.
                value: Tensor combined using the attention weights.
                mask: Optional tensor broadcastable to the logits. Positions
                    where ``mask == 0`` are filled with ``-1e9`` before the
                    softmax, so they receive (numerically) zero weight.

            Returns:
                The product of the softmax-normalised logits with ``value``.
            """
            scale = math.sqrt(query.size(-1))
            logits = (query @ key.transpose(-2, -1)) / scale

            if mask is not None:
                blocked = mask == 0
                logits = logits.masked_fill(blocked, -1e9)

            # Normalise over the key axis (last dimension of the logits).
            weights = F.softmax(logits, dim=-1)
            return weights @ value
    
    class MultiSelfAttention(nn.Module):
    
        def __init__(self, heads, d_model, dropout = 0.1):
            super().__init__()
            
            self.d_model = d_model
            self.d_k = d_model // heads
            self.h = heads
            
            self.q_linear = nn.Linear(d_model, d_model)
            self.v_linear = nn.Linear(d_model, d_model)
            self.k_linear = nn.Linear(d_model, d_model)
            
            self.dropout = nn.Dropout(dropout)
            self.out = nn.Linear(d_model, d_model)
            self.attention = ScaledDotProductAttention()
        
        def forward(self, q, k, v, mask=None):
            
            bs = q.size(0) #batch
            
            # perform linear operation and split into N heads
            k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
            q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
            v = self.v_linear(v).view(bs, -1, self.h, self.d_k)
            
            # transpose to get dimensions bs * N * sl * d_model
            k = k.transpose(1,2)
            q = q.transpose(1,2)
            v = v.transpose(1,2)
            
            # calculate attention using function we will define next
            scores = self.attention(q,k,v)
            # concatenate heads and put through final linear layer
            concat = scores.transpose(1,2).contiguous()
            .view(bs, -1, self.d_model)
            output = self.out(concat)
        
            return output
    
    
  • 相关阅读:
    C#设计模式-单例模式
    MVC图片上传并显示缩略图
    asp.net MVC发布iis无法加载css,js和图片
    Silverlight中获取控件中子控件
    Lambda加自定义比较器实现两个列表的合并
    MVC文件上传
    pt-osc测试
    MySQL DDL方案测试及选型.
    gh-ost测试
    gh-ost原理
  • 原文地址:https://www.cnblogs.com/lixyuan/p/12919894.html
Copyright © 2011-2022 走看看