zoukankan      html  css  js  c++  java
  • Transformer中的维度变换

    自己总结记录一下transformer中的维度变换

    对于输入

    input: [batch_size * max_sen_len]

    词嵌入矩阵

    vocab_matrix dim: [vocab_size * embedding_dim]

    位置编码

    PE(pos,2i)=sin(pos/10000^(2i/embedding_dim))
    PE(pos,2i+1)=cos(pos/10000^(2i/embedding_dim))

    encoder input embedding x = input token emb + position emb :
    [batch_size * max_sen_len * embedding_dim]

    对每一句话(句尾</s>):[ max_sen_len * embedding_dim ]

    ENCODER

    流程:

    input -> dropout ->
    (multihead SAN -> attention dropout -> residual connection -> LN -> FFN -> dropout -> RS connection-> LN) * 6 ->
    [batch_size, max_sen_len, embedding_dim]

    ---- multihead self atten ----
    WQ,WK,WV: embedding_dim * embedding_dim,
    其中WQ, WK, WV可以切分为多头WQ_i, Wk_i, WV_i, 即第二个维度 = embedding_dim//num_heads=d_k
    WQ_i,Wk_i,WV_i: embedding_dim * d_k
    q_i,k_i,v_i = x * (WQ_i,WK_i,WV_i) : max_sen_len * d_k

    weight compute:
    q_i * k_i / srqt(d_k) : [max_sen_len * max_sen_len]

    (softmax之前要对q和k做mask,把pad 0的维度置为-inf,这样softmax之后对应位置权重为0)

    softmax(q_i * k_i / sqrt(d_k) + Mask) * v_i = head_i, 在最后一个维度上做softmax
    head_i: [max_sen_len * d_k]
    Multi_head = concat num_heads of head_i = [head_1,head_2,...,head_8]: [max_sen_len * embedding_dim]
    W_outlayer : [ embedding_dim , embedding_dim ]
    #context = Multi_head * W_outlayer :[max_sen_len * embedding_dim]

    ---- add & norm ----
    [max_sen_len * embedding_dim]

    ----ffn & add & norm ----
    ffn = Relu(W_1 * x + b_1) * W_2 +b_2
    Relu = max(0,x)
    W_1 : [embedding_dim * ffn_hidden_size]
    b_1 : [1 * ffn_hidden_size ]
    W_2 : [ffn_hidden_size * embedding_dim]
    b_2 : [1 * embedding_dim]

    ---- enc out ----
    [batch_size, max_sen_len, embedding_dim]

    DECODER

    流程:

    decoder input -> droput ->
    (masked multihead self atten -> attention dropout -> RS connection-> LN ->
    multihead self atten -> dropout -> RS connection-> LN ->
    FFN -> dropout -> RS connection-> LN) *6 ->
    [batch_size, max_sen_len, vocab_size]

    decoder input embedding y = input token emb + position emb :
    [ batch_size * max_sen_len, embedding_dim]
    对每一句话y(要添加起始符号<s>) : [ max_sen_len * embedding_dim ]

    ENCODER的输出给每一层DECODER
    ---- masked multihead self atten ----
    上三角矩阵置为-inf
    q,k 来自encoder输出:[max_sen_len, embedding_dim]
    q_i,k_i,v_i = y * WQ_i,WK_i,WV_i : [max_sen_len * d_k]

    weight compute:
    q_i * k_i / srqt(d_k) : [max_sen_len * max_sen_len]

    softmax(q_i * k_i / sqrt(d_k) + Mask) * v_i = head_i : [max_sen_len * d_k]
    Multi_head = concat num_heads of head_i = [head_1,head_2,...,head_8]: [max_sen_len * embedding_dim]
    W_outlayer : [ embedding_dim , embedding_dim ]
    #context = Multi_head * W_outlayer :[max_sen_len * embedding_dim]

    ---- multihead self att ----
    维度变换同上
    [max_sen_len * embedding_dim]

    ---- add & norm ----
    [max_sen_len * embedding_dim]

    ----ffn & add & norm ----
    ffn = Relu(W_1 * y + b_1) * W_2 +b_2
    Relu = max(0,y)
    W_1 : [embedding_dim * ffn_hidden_size]
    b_1 : [1 * ffn_hidden_size]
    W_2 : [ffn_hidden_size * vocab_size]
    b_2 : [1 * vocab_size]

    [batch_size, max_sen_len, vocab_size]

    ---- dec out ----
    [batch_size, max_sen_len, vocab_size]
    decoder输出隐藏层变量,先乘以线性矩阵,再在最后一维做softmax(vocab_size维),得到词典库上的概率分布,
    输出最大的概率,与真实标签进行交叉熵损失的计算,汇总一句话中每个的损失,优化,训练
  • 相关阅读:
    <script>元素
    女朋友问什么是动态规划,应该怎么回答?
    从输入URL到页面展示,这中间都发生了什么?
    TypeScript之父:JS不是竞争对手,曾在惧怕开源的微软文化中艰难求生
    Flash 终将谢幕:微软将于年底停止对 Flash 的支持
    尤雨溪:TypeScript不会取代JavaScript
    JVM参数设置、分析(转发)
    -XX:PermSize -XX:MaxPermSize 永久区参数设置
    堆的分配参数
    -Xmx 和 –Xms 设置最大堆和最小堆
  • 原文地址:https://www.cnblogs.com/yh-blog/p/15115253.html
Copyright © 2011-2022 走看看