TensorFlow Study Notes (39): Bidirectional RNN

TensorFlow bidirectional RNN

How to implement a bidirectional RNN in TensorFlow

Single-layer bidirectional RNN

[Figure: single-layer bidirectional RNN (from cs224d)]


TensorFlow already provides an interface for bidirectional RNNs: tf.nn.bidirectional_dynamic_rnn(). Let's first look at how to use it.

bidirectional_dynamic_rnn(
    cell_fw,                # forward RNN cell
    cell_bw,                # backward RNN cell
    inputs,                 # the input sequences
    sequence_length=None,   # per-example sequence lengths
    initial_state_fw=None,  # initial state of the forward rnn_cell
    initial_state_bw=None,  # initial state of the backward rnn_cell
    dtype=None,             # data type
    parallel_iterations=None,
    swap_memory=False,
    time_major=False,
    scope=None
)

Return value: a tuple (outputs, output_states), where outputs is itself a tuple (output_fw, output_bw). If time_major=True, then output_fw and output_bw are also time-major, and vice versa. If you want to concatenate the two directions, simply use tf.concat(outputs, 2).
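For concreteness, a shape sketch (my own illustration, assuming 10-unit cells and batch-major inputs, i.e. time_major=False):

# outputs is (output_fw, output_bw), each of shape [batch_size, max_time, 10]
output_fw, output_bw = outputs
# concatenating along the last axis stacks the two directions per time step
out = tf.concat(outputs, 2)  # shape [batch_size, max_time, 20]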

How to use:
bidirectional_dynamic_rnn is used in much the same way as dynamic_rnn:
    
1. Define the forward and backward rnn_cells
2. Define the initial states for the forward and backward rnn_cells
3. Prepare the input sequences
4. Call bidirectional_dynamic_rnn
import tensorflow as tf
from tensorflow.contrib import rnn

cell_fw = rnn.LSTMCell(10)  # forward cell with 10 units
cell_bw = rnn.LSTMCell(10)  # backward cell with 10 units
batch_size = ...  # number of sequences per batch
# zero_state needs the dtype as well as the batch size
initial_state_fw = cell_fw.zero_state(batch_size, tf.float32)
initial_state_bw = cell_bw.zero_state(batch_size, tf.float32)
seq = ...         # inputs, shape [batch_size, max_time, input_size]
seq_length = ...  # int32 vector of per-example lengths, shape [batch_size]
(outputs, states) = tf.nn.bidirectional_dynamic_rnn(
    cell_fw, cell_bw, seq, seq_length,
    initial_state_fw, initial_state_bw)
out = tf.concat(outputs, 2)  # [batch_size, max_time, 20]
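The returned states can also serve as a fixed-size summary of each sequence. A small sketch, assuming LSTM cells as above (their states are LSTMStateTuples):

# states is (state_fw, state_bw); for LSTMCell each is an
# LSTMStateTuple(c, h), with h of shape [batch_size, 10]
state_fw, state_bw = states
sequence_repr = tf.concat([state_fw.h, state_bw.h], 1)  # [batch_size, 20]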

Multi-layer bidirectional RNN

[Figure: multi-layer bidirectional RNN (from cs224d)]

A single-layer bidirectional RNN can be implemented simply with the method above, but for a multi-layer bidirectional RNN you cannot just wrap the layers in a MultiRNNCell and pass that to bidirectional_dynamic_rnn.
To see why, we need to look at a snippet of bidirectional_dynamic_rnn's source code.

with vs.variable_scope(scope or "bidirectional_rnn"):
  # Forward direction
  with vs.variable_scope("fw") as fw_scope:
    output_fw, output_state_fw = dynamic_rnn(
        cell=cell_fw, inputs=inputs, sequence_length=sequence_length,
        initial_state=initial_state_fw, dtype=dtype,
        parallel_iterations=parallel_iterations, swap_memory=swap_memory,
        time_major=time_major, scope=fw_scope)

This is only a small piece of the code, but it is enough to see that the bi-RNN is actually built on top of dynamic_rnn. If we passed in a MultiRNNCell, the interaction between the two directions at each layer would be lost. So we can write a small utility function ourselves that calls bidirectional_dynamic_rnn once per layer to build a multi-layer bidirectional RNN. A condensed sketch of such an implementation follows; if there are mistakes, feel free to point them out.
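A minimal sketch (the helper name stacked_bidirectional_rnn and its arguments are my own; it assumes batch-major inputs and LSTM cells):

import tensorflow as tf
from tensorflow.contrib import rnn

def stacked_bidirectional_rnn(num_layers, num_units, inputs, seq_lengths):
    # inputs: [batch_size, max_time, input_size], batch-major
    _inputs = inputs
    for layer in range(num_layers):
        # a distinct scope per layer, so each layer gets its own weights
        with tf.variable_scope("birnn_layer_%d" % layer):
            cell_fw = rnn.LSTMCell(num_units)
            cell_bw = rnn.LSTMCell(num_units)
            outputs, states = tf.nn.bidirectional_dynamic_rnn(
                cell_fw, cell_bw, _inputs,
                sequence_length=seq_lengths, dtype=tf.float32)
            # feed both directions to the next layer
            _inputs = tf.concat(outputs, 2)  # [batch, time, 2 * num_units]
    return _inputs

The concatenated forward/backward outputs of one layer become the inputs of the next, which is exactly the cross-direction interaction that a MultiRNNCell per direction would miss.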

A look into the bidirectional_dynamic_rnn source

Above we already saw the implementation of the forward pass; now let's look at the remaining backward part.
The backward pass essentially does two reverses:
1. First reverse: reverse the input sequences, then run them through dynamic_rnn.
2. Second reverse: reverse the outputs returned by dynamic_rnn, so that the forward and backward outputs are aligned in time.

def _reverse(input_, seq_lengths, seq_dim, batch_dim):
  if seq_lengths is not None:
    return array_ops.reverse_sequence(
        input=input_, seq_lengths=seq_lengths,
        seq_dim=seq_dim, batch_dim=batch_dim)
  else:
    return array_ops.reverse(input_, axis=[seq_dim])

with vs.variable_scope("bw") as bw_scope:
  inputs_reverse = _reverse(
      inputs, seq_lengths=sequence_length,
      seq_dim=time_dim, batch_dim=batch_dim)
  tmp, output_state_bw = dynamic_rnn(
      cell=cell_bw, inputs=inputs_reverse, sequence_length=sequence_length,
      initial_state=initial_state_bw, dtype=dtype,
      parallel_iterations=parallel_iterations, swap_memory=swap_memory,
      time_major=time_major, scope=bw_scope)

output_bw = _reverse(
    tmp, seq_lengths=sequence_length,
    seq_dim=time_dim, batch_dim=batch_dim)

outputs = (output_fw, output_bw)
output_states = (output_state_fw, output_state_bw)

return (outputs, output_states)
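Expressed with the public API, the backward pass amounts to the following sketch (my own rewrite, reusing cell_bw, seq, and seq_length from the earlier example, batch-major):

# reverse #1: flip each input sequence (only its first seq_length steps)
seq_rev = tf.reverse_sequence(seq, seq_length, seq_axis=1, batch_axis=0)
with tf.variable_scope("bw"):
    tmp, state_bw = tf.nn.dynamic_rnn(
        cell_bw, seq_rev, sequence_length=seq_length, dtype=tf.float32)
# reverse #2: flip the outputs back, so output_bw[:, t] is aligned in time
# with output_fw[:, t]
output_bw = tf.reverse_sequence(tmp, seq_length, seq_axis=1, batch_axis=0)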

    tf.reverse_sequence

Reverses a variable-length portion of each sequence in the batch.

reverse_sequence(
    input,            # the input; the sequences to be reversed
    seq_lengths,      # 1-D tensor of per-example sequence lengths
    seq_axis=None,    # which axis is the sequence (time) axis
    batch_axis=None,  # which axis is the batch axis
    name=None,
    seq_dim=None,     # deprecated alias of seq_axis
    batch_dim=None    # deprecated alias of batch_axis
)

The example on the official site is very good, so I'll paste it here directly:

# Given this:
batch_dim = 0
seq_dim = 1
input.dims = (4, 8, ...)
seq_lengths = [7, 2, 3, 5]

# then slices of input are reversed on seq_dim, but only up to seq_lengths:
output[0, 0:7, :, ...] = input[0, 6::-1, :, ...]
output[1, 0:2, :, ...] = input[1, 1::-1, :, ...]
output[2, 0:3, :, ...] = input[2, 2::-1, :, ...]
output[3, 0:5, :, ...] = input[3, 4::-1, :, ...]

# while entries past seq_lengths are copied through:
output[0, 7:, :, ...] = input[0, 7:, :, ...]
output[1, 2:, :, ...] = input[1, 2:, :, ...]
output[2, 3:, :, ...] = input[2, 3:, :, ...]
output[3, 5:, :, ...] = input[3, 5:, :, ...]

Example 2:

# Given this:
batch_dim = 2
seq_dim = 0
input.dims = (8, ?, 4, ...)
seq_lengths = [7, 2, 3, 5]

# then slices of input are reversed on seq_dim, but only up to seq_lengths:
output[0:7, :, 0, :, ...] = input[6::-1, :, 0, :, ...]
output[0:2, :, 1, :, ...] = input[1::-1, :, 1, :, ...]
output[0:3, :, 2, :, ...] = input[2::-1, :, 2, :, ...]
output[0:5, :, 3, :, ...] = input[4::-1, :, 3, :, ...]

# while entries past seq_lengths are copied through:
output[7:, :, 0, :, ...] = input[7:, :, 0, :, ...]
output[2:, :, 1, :, ...] = input[2:, :, 1, :, ...]
output[3:, :, 2, :, ...] = input[3:, :, 2, :, ...]
output[5:, :, 3, :, ...] = input[5:, :, 3, :, ...]
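A tiny runnable check of the same behavior (TF 1.x session API; the values are my own):

import tensorflow as tf

x = tf.constant([[1, 2, 3, 4],
                 [5, 6, 7, 8]])
# reverse the first 3 entries of row 0 and the first 2 entries of row 1;
# everything past each sequence length is copied through unchanged
y = tf.reverse_sequence(x, seq_lengths=[3, 2], seq_axis=1, batch_axis=0)

with tf.Session() as sess:
    print(sess.run(y))
    # [[3 2 1 4]
    #  [6 5 7 8]]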
Original post: https://www.cnblogs.com/silence-tommy/p/8058333.html