zoukankan      html  css  js  c++  java
  • 多头Attention 和 自注意力机制

    这个多头attention确实挺搞的,这个东西绕来绕去,看torch的文档也看不懂,看源码也迷迷糊糊的,可能我的智商就是不够吧。。。枯了

    论文里的公式求法,可以看到它因为是self-multiheadsAttention。多头自注意力机制,所以它这里的Q K V 实际上是同一个东西,也就是最后一维都是相同的。

    为什么这里可以直接concat起来,是因为它将Q、K、V最后一维都进行了切割,也就是说,它的多头attention不是说使用多个attention weight,而是说对不同part部分进行attention。比如论文将Q、K、V最后一个维度切成了8块,它的8头attention,就是每个attention就对这一块部分进行attention机制,最后进行concat。这也是一个有意思的点,这样就直接用点积attention来一次矩阵乘法就行了。

    image-20211119095811841

    这里有个参考的回答:

    为什么切割方式求attention

    这里有两张参考的图片:

    image-20211119105801580

    img

    torch 文档

    image-20211119105306596

    这里的embed_dim 就是后面Q的dim(最后一维)也就是词向量的维度,这是模型输出的维度,默认q、k、v的最后维度一致。

    image-20211119105942893

    key_padding_mask 是padding mask 是掩key的

    attn_mask 是掩key_value pair的。

    这么说可能很难理解,key_padding_mask就是说句子序列中有多少个padding,这些padding是不要的。但是attn_mask 是用来说,我不能提前看到后面的词。(这个还是在自注意力机制用到),因为transformer的decoder第一层的自注意力层不能看到未来的词。

    自己实现多头注意力机制

    import torch
    import torch.nn as nn
    import math
    from d2l import torch as d2l
    class MultiAttention(nn.Module):
        def __init__(self, embed_dim, num_heads, qdim=None, kdim=None, vdim=None, hdim=256, dropout=0.0) -> None:
            super(MultiAttention, self).__init__()
            self.attention = d2l.DotProductAttention(dropout)
            self.num_heads = num_heads
            nn.MultiheadAttention()
            # 先做一个全连接层好把Q、K、V不同维度转为同一维度
            self.W_q = nn.Linear(qdim, embed_dim)
            self.W_k = nn.Linear(kdim, embed_dim)
            self.W_v = nn.Linear(vdim, embed_dim)
            self.W_o = nn.Linear(embed_dim, emded_dim)
            
        def forward(self, Q, K, V):
            # 注意这里的Q 的shape (batchsize, qn, qdim)
            # K (batchsize, kvn, kdim)
            # V (batchsize, kvn, vdim)
            Q = self.trans(self.W_q(Q), self.num_heads)
            K = self.trans(self.W_k(K), self.num_heads)
            V = self.trans(self.W_v(V), self.num_heads)
            # Q (batchsize *numheads, qn, embed_dim/num_heads)
            output = self.attention(Q, K, V)
            # output shape (batchsize*num_heads, qn, kvn, embed/num_heads)
            # 这里没有返回attentionweight,但attentionweight的shape (batchsize*num_heads, qn, kvn)
            # output最后一维的embed/num_heads,是因为我们将V的最后一维切割了。
            output = self.retrans(output, self.num_heads) 
            return self.W_o(output)
        
        def trans(self, X, num_heads):
            # X shape (batchsize, 查询或者‘键值对’数, embed_dim)
            X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
            X = X.permute(0, 2, 1, 3)
            X = X.reshape(-1, X.shape[2], X.shape[3])
            return X
        
        def retrans(self, X, num_heads):
            # X shape (batchsize*num_heads, 查询或者‘键值对’数, embed_dim/num_heads)
            X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
            X = X.permute(0, 2, 1, 3)
            X = X.reshape(X.shape[0], X.shape[1], -1)
            return X
    
    attention = MultiAttention(256, 2, 64, 64, 64)
    q= k= v= torch.ones((32, 35, 64))
    s = attention(q, k, v)
    
    

    这里我其实写的很不标准,因为几个全连接搞得挺混乱的。但其实思想也是一致的。

    自注意力机制

    image-20211119110854322

    可以看到RNN是很有时序性的,它是要求一个一个输入。CNN也可以保留一定的时序性,因为卷积核的感受野可以保留部分时序信息。但是self-attention机制是完全没有时序性的,它一次就可以看完全部。

    位置编码

    这里就引入了位置编码这个概念:

    X = X + P其中P就是位置编码,对应的值:

    \[\begin{aligned} p_{i, 2 j} &=\sin \left(\frac{i}{10000^{2 j / d}}\right) \\ p_{i, 2 j+1} &=\cos \left(\frac{i}{10000^{2 j / d}}\right) \end{aligned} \]

  • 相关阅读:
    面试题-Java多线程基础、实现工具和可见性保证(新更新版)
    Linux脚本-使用jar自动替换配置文件
    让ie兼容opacity属性的方法
    rgba兼容ie
    ie6的又一个变态问题
    长为112px的td里面有一个块span,居中时引起的兼容问题
    ie8中position不显示的问题
    haslayout
    CSS中zoom作用
    Vue学习笔记一
  • 原文地址:https://www.cnblogs.com/kalicener/p/15576275.html
Copyright © 2011-2022 走看看