`tf.nn.dynamic_rnn` builds an RNN from an `RNNCell` and unrolls it over the input sequence. Its parameters are:
```python
tf.nn.dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, dtype=None,
                  parallel_iterations=None, swap_memory=False, time_major=False, scope=None)
```
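Conceptually, `dynamic_rnn` calls the cell once per time step and stacks the per-step outputs. The NumPy sketch below mimics that loop for batch-major inputs (`time_major=False`). `simple_dynamic_rnn` and `running_sum_cell` are hypothetical names, not TensorFlow API, and the sketch assumes the state is a single array. It also mirrors the documented `sequence_length` behavior: once a batch entry is past its length, its output is zero and its state is carried through unchanged.

```python
import numpy as np

def simple_dynamic_rnn(cell_fn, inputs, initial_state, sequence_length=None):
    """Hypothetical NumPy sketch of the dynamic_rnn time loop (batch-major).

    cell_fn(x_t, state) -> (output_t, new_state); inputs: [batch, max_time, depth].
    """
    batch, max_time, _ = inputs.shape
    if sequence_length is None:
        sequence_length = np.full(batch, max_time)
    state = initial_state
    outputs = []
    for t in range(max_time):
        out_t, new_state = cell_fn(inputs[:, t, :], state)
        # Which batch entries are still inside their sequence at step t.
        alive = (t < sequence_length)[:, None]
        # Finished entries emit zeros and keep their previous state.
        outputs.append(np.where(alive, out_t, 0.0))
        state = np.where(alive, new_state, state)
    # outputs: [batch, max_time, units]; state: [batch, units]
    return np.stack(outputs, axis=1), state

def running_sum_cell(x_t, state):
    # Toy cell: the new state is the running sum, and it is also the output.
    new_state = state + x_t
    return new_state, new_state

x = np.ones((2, 4, 1))
outputs, state = simple_dynamic_rnn(running_sum_cell, x, np.zeros((2, 1)),
                                    sequence_length=np.array([4, 2]))
print(outputs[:, :, 0].tolist())  # [[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 0.0, 0.0]]
print(state[:, 0].tolist())       # [4.0, 2.0]
```

The second sequence has length 2, so its outputs are zero afterwards and its final state is the value reached at step 2, not at `max_time`.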
Example with a single `BasicRNNCell`:
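```python
import tensorflow as tf  # TensorFlow 1.x API (tf.nn.rnn_cell, tf.nn.dynamic_rnn)

# hidden_size, batch_size and input_data are assumed to be defined elsewhere;
# input_data has shape [batch_size, max_time, input_depth].

# create a BasicRNNCell
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)

# defining initial state
initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)

# 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]
# 'state' is a tensor of shape [batch_size, cell_state_size]
outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data,
                                   initial_state=initial_state,
                                   dtype=tf.float32)
```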
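A NumPy sketch of what the `BasicRNNCell` example computes, assuming the usual vanilla RNN update h_t = tanh(x_t W + h_{t-1} U + b) (equivalent to TensorFlow's single kernel applied to the concatenated `[x_t, h_{t-1}]`, up to how the weights are stored). The sizes and random weights are made up. For this cell the per-step output and the new state are the same tensor, which is why, without `sequence_length`, `state` equals the last time step of `outputs`.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, max_time, input_depth, hidden_size = 3, 5, 4, 8

# Hypothetical weights for a vanilla RNN cell.
W = rng.normal(scale=0.1, size=(input_depth, hidden_size))
U = rng.normal(scale=0.1, size=(hidden_size, hidden_size))
b = np.zeros(hidden_size)

input_data = rng.normal(size=(batch_size, max_time, input_depth))
state = np.zeros((batch_size, hidden_size))  # analogue of rnn_cell.zero_state(...)

steps = []
for t in range(max_time):
    # The step output and the new state are the same array for a basic RNN cell.
    state = np.tanh(input_data[:, t, :] @ W + state @ U + b)
    steps.append(state)
outputs = np.stack(steps, axis=1)

print(outputs.shape)                          # (3, 5, 8) -> [batch_size, max_time, hidden_size]
print(state.shape)                            # (3, 8)    -> [batch_size, hidden_size]
print(np.allclose(outputs[:, -1, :], state))  # True
```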
Two stacked LSTM layers:
```python
import tensorflow as tf  # TensorFlow 1.x API

# 'data' is assumed to be a float32 tensor of shape [batch_size, max_time, input_depth].

# create 2 LSTMCells
rnn_layers = [tf.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]

# create an RNN cell composed sequentially of a number of RNNCells
multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers)

# 'outputs' is a tensor of shape [batch_size, max_time, 256]
# 'state' is an N-tuple, where N is the number of LSTMCells, containing a
# tf.contrib.rnn.LSTMStateTuple for each cell
outputs, state = tf.nn.dynamic_rnn(cell=multi_rnn_cell,
                                   inputs=data,
                                   dtype=tf.float32)
```
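To see where the shapes in the two-layer example come from, here is a NumPy sketch of stacking two LSTM layers the way `MultiRNNCell` chains cells: at each time step, layer 1's output H is layer 2's input. `make_lstm` is a hypothetical helper using the standard LSTM gate equations with random weights; it is not TensorFlow's exact parameterization. The batch size, sequence length and input depth are made up; the layer sizes 128 and 256 match the example.

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def make_lstm(input_depth, units):
    """Hypothetical standard LSTM cell; state is a (c, h) pair."""
    W = rng.normal(scale=0.1, size=(input_depth + units, 4 * units))
    b = np.zeros(4 * units)

    def step(x, state):
        c, h = state
        z = np.concatenate([x, h], axis=1) @ W + b
        i, g, f, o = np.split(z, 4, axis=1)   # input, candidate, forget, output
        c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
        h = sigmoid(o) * np.tanh(c)
        return h, (c, h)                      # output is H; state is (C, H)
    return step

batch_size, max_time, input_depth = 2, 6, 10
sizes = [128, 256]
cells, depth = [], input_depth
for units in sizes:
    cells.append(make_lstm(depth, units))
    depth = units                             # next layer consumes this layer's H

data = rng.normal(size=(batch_size, max_time, input_depth))
states = [(np.zeros((batch_size, u)), np.zeros((batch_size, u))) for u in sizes]
outputs = []
for t in range(max_time):
    x = data[:, t, :]
    for k, cell in enumerate(cells):
        x, states[k] = cell(x, states[k])     # layer k's output feeds layer k+1
    outputs.append(x)                         # only the top layer's H is kept
outputs = np.stack(outputs, axis=1)
state = tuple(states)                         # one (c, h) pair per layer

print(outputs.shape)                          # (2, 6, 256)
print([(c.shape, h.shape) for c, h in state]) # [((2, 128), (2, 128)), ((2, 256), (2, 256))]
```

Only the last layer's H reaches `outputs`, so its last dimension is 256, while `state` keeps the final (C, H) of every layer.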
`dynamic_rnn` returns a pair `(outputs, state)`:
- `outputs`: The RNN output `Tensor`. If `time_major == False` (default), this will be a `Tensor` shaped `[batch_size, max_time, cell.output_size]`. If `time_major == True`, this will be a `Tensor` shaped `[max_time, batch_size, cell.output_size]`. Note: if `cell.output_size` is a (possibly nested) tuple of integers or `TensorShape` objects, then `outputs` will be a tuple with the same structure as `cell.output_size`, containing Tensors with shapes corresponding to the shape data in `cell.output_size`.
- `state`: The final state. If `cell.state_size` is an int, this will be shaped `[batch_size, cell.state_size]`. If it is a `TensorShape`, this will be shaped `[batch_size] + cell.state_size`. If it is a (possibly nested) tuple of ints or `TensorShape`, this will be a tuple having the corresponding shapes. If the cells are `LSTMCell`s, `state` will be a tuple containing an `LSTMStateTuple` for each cell.
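The two `outputs` layouts hold the same values and differ only in which of the first two axes is time. A small NumPy illustration with made-up sizes (in TensorFlow the same conversion would be `tf.transpose(outputs, [1, 0, 2])`):

```python
import numpy as np

batch_size, max_time, output_size = 4, 7, 3
# Batch-major layout, as returned with time_major=False (the default).
outputs_batch_major = np.arange(batch_size * max_time * output_size,
                                dtype=float).reshape(batch_size, max_time, output_size)
# Time-major layout (time_major=True): same values, axes 0 and 1 swapped.
outputs_time_major = np.transpose(outputs_batch_major, (1, 0, 2))

print(outputs_time_major.shape)  # (7, 4, 3)
# Batch entry 2 at time step 5, indexed in both layouts:
print(np.array_equal(outputs_batch_major[2, 5], outputs_time_major[5, 2]))  # True
```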
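`outputs` is the output tensor covering all time steps, while `state` is the final state, i.e. the state after the last time step (when `sequence_length` is given, after each sequence's last valid step). `state` is a single tensor only for cells like `BasicRNNCell`; for an `LSTMCell` it is an `LSTMStateTuple`, and for a `MultiRNNCell` it is a tuple with one entry per layer.
For an LSTM:
- `outputs` contains the output H of every time step;
- `state` contains both H and the cell state C of the last time step, stored as `LSTMStateTuple(c, h)`.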
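A single-layer NumPy sketch of that H/C split, assuming the standard LSTM equations with random hypothetical weights. `LSTMState` is a local namedtuple that only mirrors the `(c, h)` field order of TensorFlow's `LSTMStateTuple`. Each step emits H into `outputs`, while C exists only inside the state, so the final C can be read from `state.c` alone.

```python
import collections
import numpy as np

# Local stand-in that mirrors the (c, h) field order of LSTMStateTuple.
LSTMState = collections.namedtuple("LSTMState", ["c", "h"])

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

rng = np.random.default_rng(1)
batch_size, max_time, input_depth, units = 2, 5, 3, 4
W = rng.normal(scale=0.5, size=(input_depth + units, 4 * units))  # hypothetical weights
b = np.zeros(4 * units)

inputs = rng.normal(size=(batch_size, max_time, input_depth))
state = LSTMState(c=np.zeros((batch_size, units)), h=np.zeros((batch_size, units)))
hs = []
for t in range(max_time):
    z = np.concatenate([inputs[:, t, :], state.h], axis=1) @ W + b
    i, g, f, o = np.split(z, 4, axis=1)
    c = sigmoid(f) * state.c + sigmoid(i) * np.tanh(g)
    h = sigmoid(o) * np.tanh(c)
    state = LSTMState(c=c, h=h)
    hs.append(h)  # only H is emitted as the per-step output
outputs = np.stack(hs, axis=1)

print(outputs.shape)                            # (2, 5, 4): H for every time step
print(np.allclose(outputs[:, -1, :], state.h))  # True: the last H appears in both
# state.c has no counterpart anywhere in outputs.
```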