zoukankan      html  css  js  c++  java
  • seq2seq attention

    1.seq2seq:分为encoder和decoder

      a.在decoder中,第一时刻输入的是上encoder最后一时刻的状态,如果用了双向的rnn,那么一般使用逆序的最后一个时刻的输出(网上说实验结果比较好) 

      b.每一时刻都有一个输出,即:[batch_size,  decoder_output_size],经过一个MLP后,都跟词汇表中的每一个词都对应了一个概率,即: [batch_size, vocab_size]。

      c.将每一个时刻的输出拼接起来,那么就是[batch_size, decoder_timestep, vocab_size],然后用beam search去寻找最优的解。

    2.seq2seq attention: 在decoder的时候加入了attention机制

    解释

      a.在decoder中,第一时刻输入的是上encoder最后一时刻的状态,如果用了双向的rnn,那么一般使用逆序的最后一个时刻的输出c0(网上说实验结果比较好),以后每一个时刻的输入则是上一时刻的输出与encoder的隐状态计算attention加权的结果。

      b.attention:

        1).用c0去跟encoder的所有时间步骤中的输出,进行match,即:用c0去和所有的输入求一个相似度,那么这个就是一个权值(attention的权值),含义就是当前时刻的输入是有encoder中的哪几个时刻来决定的,就是神经网络翻译中的那张经典的图。

        2).decoder第一个时间步骤c1,会输出一个向量,那么再重复1)中的步骤用c1替换c0

      c.将每一个时刻的输出拼接起来,那么就是[batch_size, decoder_timestep, vocab_size],然后用beam search去寻找最优的解。

    代码实现

       

       a. Wh  * hi 可以用线性来计算直接一个全连接层,也可以用卷积方法来计算,但是无论用哪种方法进行计算,计算出来的shape都要是跟原来encoder state一样(这里卷积有点不同,因为输入的要是4D tensor,所以输入的是[batch_size, time_step, 1, embedd_size], 采用[1, 1]的卷积核,具体可以参考pointer generator实现),得到 r1,shape [batch_size, time_step, embedd_size]

      b.  decoder端t时刻的输入和上一时刻的context vector做一个线性变换,得到输入到LSTM的输入;LSTM输出output和state,output可以用来经过一个线性变换求词表,state用于计算attention

        1)在训练时,第一的时间步骤时,context vector为0

        2)在预测时,context vector为encoder输出的state经过attention后的结果

           c.  用state经过一个线性变化,并加上bias,就等于这部分 WS * st + b,得到r2,shape [batch_size, embedd_size]

      d.  最后再经过vt * tanh(r1 + r2)得到的r3,再对r3进行求和,得到r4, shape [batch_size, time_step], 最后再经过softmax,shape [batch_size, time_step],这里要经过mask,即让为0的部分概率为0

    Seq2Seq Attention Decoder Tensorflow Api

    https://tensorflow.google.cn/api_docs/python/tf/contrib/legacy_seq2seq/attention_decoder

  • 相关阅读:
    spring自动注入--------
    Spring p 命名和c命名(不常用)
    反射笔记-----------------------------
    -----------------------spring 事务------------------------
    --------------------------------MaBatis动态sql--------------------------
    让 div中的div垂直居中的方法!!同样是抄袭来的(*^__^*)
    div中的img垂直居中的方法,最简单! 偷学来的,,,不要说我抄袭啊(*^__^*)
    关于transform-style:preserve-3d的些许明了
    转换 transform
    计数器counter
  • 原文地址:https://www.cnblogs.com/callyblog/p/9827708.html
Copyright © 2011-2022 走看看