  • tensorflow dynamic rnn source code analysis

    Python 3.6, TensorFlow 1.11

    Test code:

    The test runs TensorFlow in eager mode, which makes it easy to debug and inspect intermediate results.

     1 import tensorflow as tf
     2 
     3 tf.enable_eager_execution()
     4 
     5 batch_size = 4 
     6 input = tf.random_normal(shape=[3, batch_size, 6], dtype=tf.float32)
     7 cell = tf.nn.rnn_cell.BasicLSTMCell(10, forget_bias=1.0, state_is_tuple=True)
     8 init_state = cell.zero_state(batch_size, dtype=tf.float32)
     9 seq_length = tf.constant([2,3,2,3],dtype=tf.int32)
    10 import pdb; pdb.set_trace()
    11 output, final_state = tf.nn.dynamic_rnn(cell, input, initial_state=init_state,sequence_length=seq_length,time_major=True) # If time_major is True, the first dimension of the input holds the RNN steps; this layout is recommended because it runs a bit faster.
    12 # If time_major is False, the second dimension of the input holds the steps.
    13 # If it is True, output has shape [steps, batch_size, depth]; otherwise it is [batch_size, max_time, depth], i.e. the same layout as the input.
    14 # final_state is the final state of the whole LSTM, containing c and h; both c and h have shape [batch_size, n_hidden].

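    For reference, here is a quick check of the shapes produced by the snippet above (a minimal sketch, assuming the snippet has just been executed in the same eager session; the random values themselves will of course differ):

        # Hedged sketch: inspect the results of the test snippet above.
        # With time_major=True the output keeps the [steps, batch_size, num_units] layout.
        print(output.shape)         # (3, 4, 10): [steps, batch_size, num_units]
        print(final_state.c.shape)  # (4, 10):    [batch_size, num_units]
        print(final_state.h.shape)  # (4, 10):    [batch_size, num_units]

        # Because seq_length = [2, 3, 2, 3], batch entries 0 and 2 end after two steps,
        # so dynamic_rnn zeroes out their outputs at the last time step:
        print(output[2, 0])         # all zeros: entry 0 has length 2, step index 2 is past it
        print(output[2, 2])         # all zeros: entry 2 has length 2 as well
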
    tf.nn.dynamic_rnn is defined in tensorflow/python/ops/rnn.py; step into it there to debug.

      1 @tf_export("nn.dynamic_rnn")
      2 def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
      3                 dtype=None, parallel_iterations=None, swap_memory=False,
      4                 time_major=False, scope=None):
      5   """Creates a recurrent neural network specified by RNNCell `cell`.
      6 
      7   Performs fully dynamic unrolling of `inputs`.
      8 
      9   Example:
     10 
     11   ```python
     12   # create a BasicRNNCell
     13   rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
     14 
     15   # 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]
     16 
     17   # defining initial state
     18   initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)
     19 
     20   # 'state' is a tensor of shape [batch_size, cell_state_size]
     21   outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data,
     22                                      initial_state=initial_state,
     23                                      dtype=tf.float32)
     24   ```
     25 
     26   ```python
     27   # create 2 LSTMCells
     28   rnn_layers = [tf.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]
     29 
     30   # create a RNN cell composed sequentially of a number of RNNCells
     31   multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers)
     32 
     33   # 'outputs' is a tensor of shape [batch_size, max_time, 256]
     34   # 'state' is a N-tuple where N is the number of LSTMCells containing a
     35   # tf.contrib.rnn.LSTMStateTuple for each cell
     36   outputs, state = tf.nn.dynamic_rnn(cell=multi_rnn_cell,
     37                                      inputs=data,
     38                                      dtype=tf.float32)
     39   ```
     40 
     41 
     42   Args:
     43     cell: An instance of RNNCell.
     44     inputs: The RNN inputs.
     45       If `time_major == False` (default), this must be a `Tensor` of shape:
     46         `[batch_size, max_time, ...]`, or a nested tuple of such
     47         elements.
     48       If `time_major == True`, this must be a `Tensor` of shape:
     49         `[max_time, batch_size, ...]`, or a nested tuple of such
     50         elements.
     51       This may also be a (possibly nested) tuple of Tensors satisfying
     52       this property.  The first two dimensions must match across all the inputs,
     53       but otherwise the ranks and other shape components may differ.
     54       In this case, input to `cell` at each time-step will replicate the
     55       structure of these tuples, except for the time dimension (from which the
     56       time is taken).
     57       The input to `cell` at each time step will be a `Tensor` or (possibly
     58       nested) tuple of Tensors each with dimensions `[batch_size, ...]`.
     59     sequence_length: (optional) An int32/int64 vector sized `[batch_size]`.
     60       Used to copy-through state and zero-out outputs when past a batch
     61       element's sequence length.  So it's more for performance than correctness.
     62     initial_state: (optional) An initial state for the RNN.
     63       If `cell.state_size` is an integer, this must be
     64       a `Tensor` of appropriate type and shape `[batch_size, cell.state_size]`.
     65       If `cell.state_size` is a tuple, this should be a tuple of
     66       tensors having shapes `[batch_size, s] for s in cell.state_size`.
     67     dtype: (optional) The data type for the initial state and expected output.
     68       Required if initial_state is not provided or RNN state has a heterogeneous
     69       dtype.
     70     parallel_iterations: (Default: 32).  The number of iterations to run in
     71       parallel.  Those operations which do not have any temporal dependency
     72       and can be run in parallel, will be.  This parameter trades off
     73       time for space.  Values >> 1 use more memory but take less time,
     74       while smaller values use less memory but computations take longer.
     75     swap_memory: Transparently swap the tensors produced in forward inference
     76       but needed for back prop from GPU to CPU.  This allows training RNNs
     77       which would typically not fit on a single GPU, with very minimal (or no)
     78       performance penalty.
     79     time_major: The shape format of the `inputs` and `outputs` Tensors.
     80       If true, these `Tensors` must be shaped `[max_time, batch_size, depth]`.
     81       If false, these `Tensors` must be shaped `[batch_size, max_time, depth]`.
     82       Using `time_major = True` is a bit more efficient because it avoids
     83       transposes at the beginning and end of the RNN calculation.  However,
     84       most TensorFlow data is batch-major, so by default this function
     85       accepts input and emits output in batch-major form.
     86     scope: VariableScope for the created subgraph; defaults to "rnn".
     87 
     88   Returns:
     89     A pair (outputs, state) where:
     90 
     91     outputs: The RNN output `Tensor`.
     92 
     93       If time_major == False (default), this will be a `Tensor` shaped:
     94         `[batch_size, max_time, cell.output_size]`.
     95 
     96       If time_major == True, this will be a `Tensor` shaped:
     97         `[max_time, batch_size, cell.output_size]`.
     98 
     99       Note, if `cell.output_size` is a (possibly nested) tuple of integers
    100       or `TensorShape` objects, then `outputs` will be a tuple having the
    101       same structure as `cell.output_size`, containing Tensors having shapes
    102       corresponding to the shape data in `cell.output_size`.
    103 
    104     state: The final state.  If `cell.state_size` is an int, this
    105       will be shaped `[batch_size, cell.state_size]`.  If it is a
    106       `TensorShape`, this will be shaped `[batch_size] + cell.state_size`.
    107       If it is a (possibly nested) tuple of ints or `TensorShape`, this will
    108       be a tuple having the corresponding shapes. If cells are `LSTMCells`
    109       `state` will be a tuple containing a `LSTMStateTuple` for each cell.
    110 
    111   Raises:
    112     TypeError: If `cell` is not an instance of RNNCell.
    113     ValueError: If inputs is None or an empty list.
    114   """
    115   rnn_cell_impl.assert_like_rnncell("cell", cell)
    116 
    117   with vs.variable_scope(scope or "rnn") as varscope:
    118     # Create a new scope in which the caching device is either
    119     # determined by the parent scope, or is set to place the cached
    120     # Variable using the same placement as for the rest of the RNN.
    121     if _should_cache():
    122       if varscope.caching_device is None:
    123         varscope.set_caching_device(lambda op: op.device)
    124 
    125     # By default, time_major==False and inputs are batch-major: shaped
    126     #   [batch, time, depth]
    127     # For internal calculations, we transpose to [time, batch, depth]
    128     flat_input = nest.flatten(inputs)
    129 
    130     if not time_major:
    131       # (B,T,D) => (T,B,D)
    132       flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input]
    133       flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input)
    134 
    135     parallel_iterations = parallel_iterations or 32
    136     if sequence_length is not None:
    137       sequence_length = math_ops.to_int32(sequence_length)
    138       if sequence_length.get_shape().ndims not in (None, 1):
    139         raise ValueError(
    140             "sequence_length must be a vector of length batch_size, "
    141             "but saw shape: %s" % sequence_length.get_shape())
    142       sequence_length = array_ops.identity(  # Just to find it in the graph.
    143           sequence_length, name="sequence_length")
    144 
    145     batch_size = _best_effort_input_batch_size(flat_input)
    146 
    147     if initial_state is not None:
    148       state = initial_state
    149     else:
    150       if not dtype:
    151         raise ValueError("If there is no initial_state, you must give a dtype.")
    152       if getattr(cell, "get_initial_state", None) is not None:
    153         state = cell.get_initial_state(
    154             inputs=None, batch_size=batch_size, dtype=dtype)
    155       else:
    156         state = cell.zero_state(batch_size, dtype)
    157 
    158     def _assert_has_shape(x, shape):
    159       x_shape = array_ops.shape(x)
    160       packed_shape = array_ops.stack(shape)
    161       return control_flow_ops.Assert(
    162           math_ops.reduce_all(math_ops.equal(x_shape, packed_shape)),
    163           ["Expected shape for Tensor %s is " % x.name,
    164            packed_shape, " but saw shape: ", x_shape])
    165 
    166     if not context.executing_eagerly() and sequence_length is not None:
    167       # Perform some shape validation
    168       with ops.control_dependencies(
    169           [_assert_has_shape(sequence_length, [batch_size])]):
    170         sequence_length = array_ops.identity(
    171             sequence_length, name="CheckSeqLen")
    172 
    173     inputs = nest.pack_sequence_as(structure=inputs, flat_sequence=flat_input)
    174 
    175     (outputs, final_state) = _dynamic_rnn_loop(
    176         cell,
    177         inputs,
    178         state,
    179         parallel_iterations=parallel_iterations,
    180         swap_memory=swap_memory,
    181         sequence_length=sequence_length,
    182         dtype=dtype)
    183 
    184     # Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth].
    185     # If we are performing batch-major calculations, transpose output back
    186     # to shape [batch, time, depth]
    187     if not time_major:
    188       # (T,B,D) => (B,T,D)
    189       outputs = nest.map_structure(_transpose_batch_time, outputs)
    190 
    191     return (outputs, final_state)

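    Before the loop, when time_major=False, lines 130-133 above transpose every flattened input from batch-major to time-major, and lines 187-189 transpose the outputs back afterwards. A minimal sketch of what _transpose_batch_time amounts to for a rank-3 tensor (the real helper handles any rank >= 2 and also restores static shape information):

        # Hedged sketch of the (B,T,D) <=> (T,B,D) swap dynamic_rnn performs internally
        # when time_major=False: for a rank-3 input it is just a transpose of the first
        # two axes.
        import tensorflow as tf

        batch_major = tf.random_normal([4, 3, 6])                 # [batch, time, depth]
        time_major_input = tf.transpose(batch_major, [1, 0, 2])   # [time, batch, depth]
        print(time_major_input.shape)                             # (3, 4, 6)
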
    Finally, _dynamic_rnn_loop is called:

      1 def _dynamic_rnn_loop(cell,
      2                       inputs,
      3                       initial_state,
      4                       parallel_iterations,
      5                       swap_memory,
      6                       sequence_length=None,
      7                       dtype=None):
      8   """Internal implementation of Dynamic RNN.
      9 
     10   Args:
     11     cell: An instance of RNNCell.
     12     inputs: A `Tensor` of shape [time, batch_size, input_size], or a nested
     13       tuple of such elements.
     14     initial_state: A `Tensor` of shape `[batch_size, state_size]`, or if
     15       `cell.state_size` is a tuple, then this should be a tuple of
     16       tensors having shapes `[batch_size, s] for s in cell.state_size`.
     17     parallel_iterations: Positive Python int.
     18     swap_memory: A Python boolean
     19     sequence_length: (optional) An `int32` `Tensor` of shape [batch_size].
     20     dtype: (optional) Expected dtype of output. If not specified, inferred from
     21       initial_state.
     22 
     23   Returns:
     24     Tuple `(final_outputs, final_state)`.
     25     final_outputs:
     26       A `Tensor` of shape `[time, batch_size, cell.output_size]`.  If
     27       `cell.output_size` is a (possibly nested) tuple of ints or `TensorShape`
     28       objects, then this returns a (possibly nested) tuple of Tensors matching
     29       the corresponding shapes.
     30     final_state:
     31       A `Tensor`, or possibly nested tuple of Tensors, matching in length
     32       and shapes to `initial_state`.
     33   Raises:
     34     ValueError: If the input depth cannot be inferred via shape inference
     35       from the inputs.
     36   """
     37   import pdb;pdb.set_trace()
     38   state = initial_state
     39   assert isinstance(parallel_iterations, int), "parallel_iterations must be int"
     40 
     41   state_size = cell.state_size#LSTMStateTuple(c=10, h=10)
     42 
     43   flat_input = nest.flatten(inputs)#list,~[0].shape=TensorShape([Dimension(3), Dimension(4), Dimension(6)])
     44   flat_output_size = nest.flatten(cell.output_size)#[10]
     45 
     46   # Construct an initial output
     47   input_shape = array_ops.shape(flat_input[0])#array([3, 4, 6])
     48   time_steps = input_shape[0]#3
     49   batch_size = _best_effort_input_batch_size(flat_input)#4
     50 
     51   inputs_got_shape = tuple(input_.get_shape().with_rank_at_least(3)
     52                            for input_ in flat_input)#(TensorShape([Dimension(3), Dimension(4), Dimension(6)]),)
     53 
     54   const_time_steps, const_batch_size = inputs_got_shape[0].as_list()[:2]#3,4
     55 
     56   for shape in inputs_got_shape:
     57     if not shape[2:].is_fully_defined():
     58       raise ValueError(
     59           "Input size (depth of inputs) must be accessible via shape inference,"
     60           " but saw value None.")
     61     got_time_steps = shape[0].value#3
     62     got_batch_size = shape[1].value#4
     63     if const_time_steps != got_time_steps:
     64       raise ValueError(
     65           "Time steps is not the same for all the elements in the input in a "
     66           "batch.")
     67     if const_batch_size != got_batch_size:
     68       raise ValueError(
     69           "Batch_size is not the same for all the elements in the input.")
     70 
     71   # Prepare dynamic conditional copying of state & output
     72   def _create_zero_arrays(size):
     73     size = _concat(batch_size, size)
     74     return array_ops.zeros(
     75         array_ops.stack(size), _infer_state_dtype(dtype, state))
     76 
     77   flat_zero_output = tuple(_create_zero_arrays(output)
     78                            for output in flat_output_size)#tuple,~[0].shape:TensorShape([Dimension(4), Dimension(10)])
     79   zero_output = nest.pack_sequence_as(structure=cell.output_size,
     80                                       flat_sequence=flat_zero_output)#TensorShape([Dimension(4), Dimension(10)])
     81 
     82   if sequence_length is not None:
     83     min_sequence_length = math_ops.reduce_min(sequence_length)#2
     84     max_sequence_length = math_ops.reduce_max(sequence_length)#3
     85   else:
     86     max_sequence_length = time_steps
     87 
     88   time = array_ops.constant(0, dtype=dtypes.int32, name="time")
     89 
     90   with ops.name_scope("dynamic_rnn") as scope:
     91     base_name = scope
     92 
     93   def _create_ta(name, element_shape, dtype):
     94     return tensor_array_ops.TensorArray(dtype=dtype,
     95                                         size=time_steps,
     96                                         element_shape=element_shape,
     97                                         tensor_array_name=base_name + name)
     98 
     99   in_graph_mode = not context.executing_eagerly()
    100   if in_graph_mode:
    101     output_ta = tuple(
    102         _create_ta(
    103             "output_%d" % i,
    104             element_shape=(tensor_shape.TensorShape([const_batch_size])
    105                            .concatenate(
    106                                _maybe_tensor_shape_from_tensor(out_size))),
    107             dtype=_infer_state_dtype(dtype, state))
    108         for i, out_size in enumerate(flat_output_size))
    109     input_ta = tuple(
    110         _create_ta(
    111             "input_%d" % i,
    112             element_shape=flat_input_i.shape[1:],
    113             dtype=flat_input_i.dtype)
    114         for i, flat_input_i in enumerate(flat_input))
    115     input_ta = tuple(ta.unstack(input_)
    116                      for ta, input_ in zip(input_ta, flat_input))
    117   else:
    118     output_ta = tuple([0 for _ in range(time_steps.numpy())]
    119                       for i in range(len(flat_output_size)))#([0, 0, 0],)
    120     input_ta = flat_input#list,~[0].shape=TensorShape([Dimension(3), Dimension(4), Dimension(6)])
    121 
    122   def _time_step(time, output_ta_t, state):
    123     """Take a time step of the dynamic RNN.
    124 
    125     Args:
    126       time: int32 scalar Tensor.
    127       output_ta_t: List of `TensorArray`s that represent the output.
    128       state: nested tuple of vector tensors that represent the state.
    129 
    130     Returns:
    131       The tuple (time + 1, output_ta_t with updated flow, new_state).
    132     """
    133     import pdb;pdb.set_trace()
    134     if in_graph_mode:
    135       input_t = tuple(ta.read(time) for ta in input_ta)
    136       # Restore some shape information
    137       for input_, shape in zip(input_t, inputs_got_shape):
    138         input_.set_shape(shape[1:])
    139     else:
    140       input_t = tuple(ta[time.numpy()] for ta in input_ta)#TensorShape([Dimension(4), Dimension(6)])
    141 
    142     input_t = nest.pack_sequence_as(structure=inputs, flat_sequence=input_t)#TensorShape([Dimension(4), Dimension(6)])
    143     # Keras RNN cells only accept state as list, even if it's a single tensor.
    144     is_keras_rnn_cell = _is_keras_rnn_cell(cell)
    145     if is_keras_rnn_cell and not nest.is_sequence(state):
    146       state = [state]
    147     call_cell = lambda: cell(input_t, state)
    148 
    149     if sequence_length is not None:
    150       (output, new_state) = _rnn_step(
    151           time=time,
    152           sequence_length=sequence_length,
    153           min_sequence_length=min_sequence_length,
    154           max_sequence_length=max_sequence_length,
    155           zero_output=zero_output,
    156           state=state,
    157           call_cell=call_cell,
    158           state_size=state_size,
    159           skip_conditionals=True)
    160     else:
    161       (output, new_state) = call_cell()
    162 
    163     # Keras cells always wrap state as list, even if it's a single tensor.
    164     if is_keras_rnn_cell and len(new_state) == 1:
    165       new_state = new_state[0]
    166     # Pack state if using state tuples
    167     output = nest.flatten(output)
    168 
    169     if in_graph_mode:
    170       output_ta_t = tuple(
    171           ta.write(time, out) for ta, out in zip(output_ta_t, output))
    172     else:
    173       for ta, out in zip(output_ta_t, output):
    174         ta[time.numpy()] = out
    175 
    176     return (time + 1, output_ta_t, new_state)
    177 
    178   if in_graph_mode:
    179     # Make sure that we run at least 1 step, if necessary, to ensure
    180     # the TensorArrays pick up the dynamic shape.
    181     loop_bound = math_ops.minimum(
    182         time_steps, math_ops.maximum(1, max_sequence_length))
    183   else:
    184     # Using max_sequence_length isn't currently supported in the Eager branch.
    185     loop_bound = time_steps#3
    186 
    187   _, output_final_ta, final_state = control_flow_ops.while_loop(
    188       cond=lambda time, *_: time < loop_bound,
    189       body=_time_step,
    190       loop_vars=(time, output_ta, state),
    191       parallel_iterations=parallel_iterations,
    192       maximum_iterations=time_steps,
    193       swap_memory=swap_memory)
    194 
    195   # Unpack final output if not using output tuples.
    196   if in_graph_mode:
    197     final_outputs = tuple(ta.stack() for ta in output_final_ta)
    198     # Restore some shape information
    199     for output, output_size in zip(final_outputs, flat_output_size):
    200       shape = _concat(
    201           [const_time_steps, const_batch_size], output_size, static=True)
    202       output.set_shape(shape)
    203   else:
    204     final_outputs = output_final_ta
    205 
    206   final_outputs = nest.pack_sequence_as(
    207       structure=cell.output_size, flat_sequence=final_outputs)
    208   if not in_graph_mode:
    209     final_outputs = nest.map_structure_up_to(
    210         cell.output_size, lambda x: array_ops.stack(x, axis=0), final_outputs)
    211 
    212   return (final_outputs, final_state)

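    In graph mode (lines 100-116 above) the flattened inputs are unstacked along the time axis into TensorArrays that the loop body reads one step at a time (line 135); the per-step outputs are written into output TensorArrays (lines 170-171) and stacked back into a [time, batch, depth] tensor after the loop (line 197). In eager mode plain Python lists are used instead (lines 117-120). A minimal sketch of the graph-mode read/write pattern, with a dummy computation standing in for the cell call:

        # Hedged sketch of the TensorArray pattern used by _dynamic_rnn_loop in graph mode:
        # unstack along time, read one step per iteration, write one output per iteration,
        # then stack the outputs back into a [time, batch, depth] tensor.
        import tensorflow as tf

        x = tf.random_normal([3, 4, 6])                        # [time, batch, depth]
        input_ta = tf.TensorArray(dtype=tf.float32, size=3).unstack(x)
        output_ta = tf.TensorArray(dtype=tf.float32, size=3)

        for t in range(3):
            step_input = input_ta.read(t)                      # [batch, depth]
            output_ta = output_ta.write(t, step_input * 2.0)   # stand-in for the cell call

        outputs = output_ta.stack()                            # back to [time, batch, depth]
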
    As you can see, dynamic_rnn mainly relies on while_loop to handle batches whose sequences have different lengths.
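
    A stripped-down sketch of that control flow (lines 187-193 of _dynamic_rnn_loop): the real loop carries (time, the output TensorArrays, state) and stops when time reaches loop_bound; the sketch below keeps just (time, state) and uses a dummy body in place of _time_step:

        # Hedged sketch of the while_loop skeleton used by _dynamic_rnn_loop.
        import tensorflow as tf

        time_steps = tf.constant(3)
        time = tf.constant(0)
        state = tf.zeros([4, 10])              # stand-in for the LSTM state tuple

        def body(time, state):
            new_state = state + 1.0            # stand-in for cell(input_t, state)
            return time + 1, new_state

        _, final_state_sketch = tf.while_loop(
            cond=lambda time, *_: time < time_steps,
            body=body,
            loop_vars=(time, state),
            parallel_iterations=32,
            maximum_iterations=time_steps)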

    From lines 82-86 above, if the sequence_length argument is not provided, max_sequence_length = time_steps = input.shape[0]. When sequence_length is provided, _rnn_step is called and, for each batch entry, the outputs at steps beyond its sequence length are set to 0; this is implemented at lines 60 and 70 of the code below (see the sketch after the _rnn_step listing).

      1 def _rnn_step(
      2     time, sequence_length, min_sequence_length, max_sequence_length,
      3     zero_output, state, call_cell, state_size, skip_conditionals=False):
      4   """Calculate one step of a dynamic RNN minibatch.
      5 
      6   Returns an (output, state) pair conditioned on `sequence_length`.
      7   When skip_conditionals=False, the pseudocode is something like:
      8 
      9   if t >= max_sequence_length:
     10     return (zero_output, state)
     11   if t < min_sequence_length:
     12     return call_cell()
     13 
     14   # Selectively output zeros or output, old state or new state depending
     15   # on whether we've finished calculating each row.
     16   new_output, new_state = call_cell()
     17   final_output = np.vstack([
     18     zero_output if time >= sequence_length[r] else new_output_r
     19     for r, new_output_r in enumerate(new_output)
     20   ])
     21   final_state = np.vstack([
     22     state[r] if time >= sequence_length[r] else new_state_r
     23     for r, new_state_r in enumerate(new_state)
     24   ])
     25   return (final_output, final_state)
     26 
     27   Args:
     28     time: int32 `Tensor` scalar.
     29     sequence_length: int32 `Tensor` vector of size [batch_size].
     30     min_sequence_length: int32 `Tensor` scalar, min of sequence_length.
     31     max_sequence_length: int32 `Tensor` scalar, max of sequence_length.
     32     zero_output: `Tensor` vector of shape [output_size].
     33     state: Either a single `Tensor` matrix of shape `[batch_size, state_size]`,
     34       or a list/tuple of such tensors.
     35     call_cell: lambda returning tuple of (new_output, new_state) where
     36       new_output is a `Tensor` matrix of shape `[batch_size, output_size]`.
     37       new_state is a `Tensor` matrix of shape `[batch_size, state_size]`.
     38     state_size: The `cell.state_size` associated with the state.
     39     skip_conditionals: Python bool, whether to skip using the conditional
     40       calculations.  This is useful for `dynamic_rnn`, where the input tensor
     41       matches `max_sequence_length`, and using conditionals just slows
     42       everything down.
     43 
     44   Returns:
     45     A tuple of (`final_output`, `final_state`) as given by the pseudocode above:
     46       final_output is a `Tensor` matrix of shape [batch_size, output_size]
     47       final_state is either a single `Tensor` matrix, or a tuple of such
     48         matrices (matching length and shapes of input `state`).
     49 
     50   Raises:
     51     ValueError: If the cell returns a state tuple whose length does not match
     52       that returned by `state_size`.
     53   """
     54   import pdb;pdb.set_trace()
     55   # Convert state to a list for ease of use
     56   flat_state = nest.flatten(state)#[c,h],shape=[4,10]
     57   flat_zero_output = nest.flatten(zero_output)#list,~[0].shape:TensorShape([Dimension(4), Dimension(10)])
     58 
     59   # Vector describing which batch entries are finished.
     60   copy_cond = time >= sequence_length#step1:array([False, False, False, False])
     61 
     62   def _copy_one_through(output, new_output):
     63     # TensorArray and scalar get passed through.
     64     if isinstance(output, tensor_array_ops.TensorArray):
     65       return new_output
     66     if output.shape.ndims == 0:
     67       return new_output
     68     # Otherwise propagate the old or the new value.
     69     with ops.colocate_with(new_output):
     70       return array_ops.where(copy_cond, output, new_output)# entries past their sequence length take the zero output (or keep the old state)
     71 
     72   def _copy_some_through(flat_new_output, flat_new_state):
     73     # Use broadcasting select to determine which values should get
     74     # the previous state & zero output, and which values should get
     75     # a calculated state & output.
     76     flat_new_output = [
     77         _copy_one_through(zero_output, new_output)
     78         for zero_output, new_output in zip(flat_zero_output, flat_new_output)]
     79     flat_new_state = [
     80         _copy_one_through(state, new_state)
     81         for state, new_state in zip(flat_state, flat_new_state)]
     82     return flat_new_output + flat_new_state
     83 
     84   def _maybe_copy_some_through():
     85     """Run RNN step.  Pass through either no or some past state."""
     86     new_output, new_state = call_cell()
     87 
     88     nest.assert_same_structure(state, new_state)
     89 
     90     flat_new_state = nest.flatten(new_state)
     91     flat_new_output = nest.flatten(new_output)
     92     return control_flow_ops.cond(
     93         # if t < min_seq_len: calculate and return everything
     94         time < min_sequence_length, lambda: flat_new_output + flat_new_state,
     95         # else copy some of it through
     96         lambda: _copy_some_through(flat_new_output, flat_new_state))
     97 
     98   # TODO(ebrevdo): skipping these conditionals may cause a slowdown,
     99   # but benefits from removing cond() and its gradient.  We should
    100   # profile with and without this switch here.
    101   if skip_conditionals:
    102     # Instead of using conditionals, perform the selective copy at all time
    103     # steps.  This is faster when max_seq_len is equal to the number of unrolls
    104     # (which is typical for dynamic_rnn).
    105     new_output, new_state = call_cell()
    106     nest.assert_same_structure(state, new_state)
    107     new_state = nest.flatten(new_state)#[c,h],shape=(4, 10)
    108     new_output = nest.flatten(new_output)#shape=(4, 10)
    109     final_output_and_state = _copy_some_through(new_output, new_state)
    110   else:
    111     empty_update = lambda: flat_zero_output + flat_state
    112     final_output_and_state = control_flow_ops.cond(
    113         # if t >= max_seq_len: copy all state through, output zeros
    114         time >= max_sequence_length, empty_update,
    115         # otherwise calculation is required: copy some or all of it through
    116         _maybe_copy_some_through)
    117 
    118   if len(final_output_and_state) != len(flat_zero_output) + len(flat_state):
    119     raise ValueError("Internal error: state and output were not concatenated "
    120                      "correctly.")
    121   final_output = final_output_and_state[:len(flat_zero_output)]
    122   final_state = final_output_and_state[len(flat_zero_output):]
    123 
    124   for output, flat_output in zip(final_output, flat_zero_output):
    125     output.set_shape(flat_output.get_shape())
    126   for substate, flat_substate in zip(final_state, flat_state):
    127     if not isinstance(substate, tensor_array_ops.TensorArray):
    128       substate.set_shape(flat_substate.get_shape())
    129 
    130   final_output = nest.pack_sequence_as(
    131       structure=zero_output, flat_sequence=final_output)
    132   final_state = nest.pack_sequence_as(
    133       structure=state, flat_sequence=final_state)
    134 
    135   return final_output, final_state
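
    The selective copy at lines 60-82 above boils down to an element selection over the batch dimension: batch entries whose sequence has already ended (time >= sequence_length) take the zero output and keep their previous state, while the others take the values just computed by the cell. A standalone sketch of that masking (all tensor names below are made up for illustration):

        # Hedged sketch of the masking done by _copy_one_through / _copy_some_through:
        # rows with time >= sequence_length get the zero output and keep the old state;
        # the remaining rows get the freshly computed values.
        import tensorflow as tf

        sequence_length = tf.constant([2, 3, 2, 3])
        time = tf.constant(2)
        copy_cond = time >= sequence_length            # [True, False, True, False]

        new_output = tf.random_normal([4, 10])         # hypothetical cell output at this step
        zero_output = tf.zeros([4, 10])
        old_state = tf.zeros([4, 10])                  # hypothetical previous state
        new_state = tf.random_normal([4, 10])          # hypothetical state returned by the cell

        final_output = tf.where(copy_cond, zero_output, new_output)  # finished rows -> zeros
        final_state = tf.where(copy_cond, old_state, new_state)      # finished rows keep the old state
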
  • Original article: https://www.cnblogs.com/buyizhiyou/p/9883182.html