To build an RNN from an RNNCell, call tf.nn.dynamic_rnn, which takes the following parameters:
    tf.nn.dynamic_rnn(
        cell,
        inputs,
        sequence_length=None,
        initial_state=None,
        dtype=None,
        parallel_iterations=None,
        swap_memory=False,
        time_major=False,
        scope=None
    )
Example, using a single BasicRNNCell:
    import tensorflow as tf  # TensorFlow 1.x API

    # create a BasicRNNCell
    rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)

    # defining initial state
    initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)

    # 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]
    # 'state' is a tensor of shape [batch_size, cell_state_size]
    outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data,
                                       initial_state=initial_state,
                                       dtype=tf.float32)
Using two stacked LSTM layers:
    # create 2 LSTMCells
    rnn_layers = [tf.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]

    # create a RNN cell composed sequentially of a number of RNNCells
    multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers)

    # 'outputs' is a tensor of shape [batch_size, max_time, 256]
    # 'state' is a N-tuple where N is the number of LSTMCells containing a
    # tf.contrib.rnn.LSTMStateTuple for each cell
    outputs, state = tf.nn.dynamic_rnn(cell=multi_rnn_cell,
                                       inputs=data,
                                       dtype=tf.float32)
The call returns a pair, outputs and state:
- outputs: The RNN output Tensor. If time_major == False (default), this will be a Tensor shaped [batch_size, max_time, cell.output_size]. If time_major == True, this will be a Tensor shaped [max_time, batch_size, cell.output_size]. Note, if cell.output_size is a (possibly nested) tuple of integers or TensorShape objects, then outputs will be a tuple having the same structure as cell.output_size, containing Tensors having shapes corresponding to the shape data in cell.output_size.
- state: The final state. If cell.state_size is an int, this will be shaped [batch_size, cell.state_size]. If it is a TensorShape, this will be shaped [batch_size] + cell.state_size. If it is a (possibly nested) tuple of ints or TensorShape, this will be a tuple having the corresponding shapes. If cells are LSTMCells, state will be a tuple containing a LSTMStateTuple for each cell.
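The two layouts selected by time_major differ only in the order of the first two axes. The following is a small NumPy sketch of that relationship, not TensorFlow code; the array names and sizes are made up for illustration.

```python
import numpy as np

batch_size, max_time, depth = 2, 5, 3

# Batch-major layout, matching time_major == False (the default):
# [batch_size, max_time, depth]
batch_major = np.arange(batch_size * max_time * depth).reshape(
    batch_size, max_time, depth)

# Swapping the first two axes gives the time-major layout:
# [max_time, batch_size, depth]
time_major = np.transpose(batch_major, (1, 0, 2))

print(batch_major.shape)  # (2, 5, 3)
print(time_major.shape)   # (5, 2, 3)

# The same vector (sequence 1, time step 3) is reachable in either layout.
print(np.array_equal(batch_major[1, 3], time_major[3, 1]))  # True
```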
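outputs is the tensor holding the cell output at every time step. state is the cell state after the last time step: for a BasicRNNCell it is a single tensor, while for an LSTMCell it is an LSTMStateTuple.
For an LSTM, outputs contains the output H of every time step;
state contains the output H and the cell state C of the last time step.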
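To make the relationship between outputs and state concrete, here is a minimal NumPy sketch of a single LSTM layer unrolled over time. It is a simplified illustration rather than TensorFlow's implementation: the function lstm_unroll, the weight layout, and the gate order are assumptions, and details such as the forget bias and sequence_length handling are left out.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_unroll(inputs, W, b, hidden_size):
    """Unroll one LSTM layer over batch-major inputs [batch, max_time, depth].

    Hypothetical helper: returns (outputs, (c, h)), mirroring the
    (outputs, state) pair described in the text.
    """
    batch_size, max_time, _ = inputs.shape
    h = np.zeros((batch_size, hidden_size))  # zero initial output H
    c = np.zeros((batch_size, hidden_size))  # zero initial cell state C
    outputs = []
    for t in range(max_time):
        x_t = inputs[:, t, :]
        # One affine map produces all four gate pre-activations.
        z = np.concatenate([x_t, h], axis=1) @ W + b
        i, j, f, o = np.split(z, 4, axis=1)
        c = sigmoid(f) * c + sigmoid(i) * np.tanh(j)
        h = sigmoid(o) * np.tanh(c)
        outputs.append(h)  # H of this time step
    # Stack along axis 1 to get [batch, max_time, hidden_size].
    return np.stack(outputs, axis=1), (c, h)

rng = np.random.default_rng(0)
batch_size, max_time, depth, hidden_size = 2, 5, 3, 4
inputs = rng.normal(size=(batch_size, max_time, depth))
W = rng.normal(scale=0.1, size=(depth + hidden_size, 4 * hidden_size))
b = np.zeros(4 * hidden_size)

outputs, (c, h) = lstm_unroll(inputs, W, b, hidden_size)
print(outputs.shape)      # (2, 5, 4): H for every time step
print(h.shape, c.shape)   # (2, 4) (2, 4): H and C of the last step
print(np.array_equal(outputs[:, -1, :], h))  # True
```

In this sketch the last time slice of outputs is the same array as the H part of the final state, while C appears only in the state.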