  • Using tf.nn.dynamic_rnn and understanding its outputs and state

    tf.nn.dynamic_rnn builds an RNN from an RNNCell. It takes the following parameters:

    tf.nn.dynamic_rnn(
        cell,
        inputs,
        sequence_length=None,
        initial_state=None,
        dtype=None,
        parallel_iterations=None,
        swap_memory=False,
        time_major=False,
        scope=None
    )

    Example using a BasicRNNCell:

    import tensorflow as tf  # TensorFlow 1.x API

    # hidden_size, batch_size and input_data are assumed to be defined by the caller

    # create a BasicRNNCell
    rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)

    # defining initial state
    initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)

    # 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]
    # 'state' is a tensor of shape [batch_size, cell_state_size]
    outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data,
                                       initial_state=initial_state,
                                       dtype=tf.float32)
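
    To make the two return values more concrete, here is a minimal pure-Python sketch that unrolls a tanh RNN by hand. It is not TensorFlow code: simple_rnn, w_x, w_h and bias are hypothetical names, and the update rule h_t = tanh(x_t W + h_{t-1} U + b) is assumed to match what BasicRNNCell computes. It also mimics how sequence_length is documented to behave: outputs past a sequence's length are zeroed and the state is copied through from the last valid step.

```python
import math


def simple_rnn(inputs, w_x, w_h, bias, sequence_length=None):
    """Unroll a tanh RNN over batch-major inputs [batch][time][input_dim].

    Returns (outputs, state):
      outputs has shape [batch][time][hidden] (one output per time step),
      state has shape [batch][hidden] (state after the last valid step).
    """
    hidden = len(bias)
    outputs, final_states = [], []
    for b, sequence in enumerate(inputs):
        valid = len(sequence) if sequence_length is None else sequence_length[b]
        h = [0.0] * hidden  # zero initial state, similar to zero_state()
        row = []
        for t, x in enumerate(sequence):
            if t < valid:
                # h_t = tanh(x_t W + h_{t-1} U + b)
                h = [
                    math.tanh(
                        sum(x[i] * w_x[i][j] for i in range(len(x)))
                        + sum(h[k] * w_h[k][j] for k in range(hidden))
                        + bias[j]
                    )
                    for j in range(hidden)
                ]
                row.append(list(h))
            else:
                # padded step: emit a zero output and keep the previous state
                row.append([0.0] * hidden)
        outputs.append(row)
        final_states.append(list(h))
    return outputs, final_states


# Hypothetical toy data: batch of 2, max_time 3, input_dim 2, hidden size 3.
inputs = [
    [[1.0, 0.5], [0.2, -0.3], [0.0, 0.0]],  # real length 2, last step is padding
    [[0.3, 0.1], [0.4, 0.9], [-0.5, 0.7]],  # real length 3
]
w_x = [[0.1, -0.2, 0.3], [0.4, 0.5, -0.6]]                   # [input_dim][hidden]
w_h = [[0.1, 0.0, 0.2], [0.0, 0.3, 0.0], [-0.1, 0.0, 0.1]]  # [hidden][hidden]
bias = [0.0, 0.1, -0.1]

outputs, state = simple_rnn(inputs, w_x, w_h, bias, sequence_length=[2, 3])
print(len(outputs), len(outputs[0]), len(outputs[0][0]))  # batch, time, hidden
print(state[0] == outputs[0][1])  # state equals the output at the last valid step
```

    For a plain RNN cell like this one, the output and the state at each step are the same vector, so the final state is simply the last valid row of outputs.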

    Using a two-layer LSTM:

    # create 2 LSTMCells
    rnn_layers = [tf.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]
    
    # create a RNN cell composed sequentially of a number of RNNCells
    multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers)
    
    # 'outputs' is a tensor of shape [batch_size, max_time, 256]
    # 'state' is a N-tuple where N is the number of LSTMCells containing a
    # tf.contrib.rnn.LSTMStateTuple for each cell
    outputs, state = tf.nn.dynamic_rnn(cell=multi_rnn_cell,
                                       inputs=data,
                                       dtype=tf.float32)

    The call returns outputs and state:

    • outputs: The RNN output Tensor.

      If time_major == False (default), this will be a Tensor shaped: [batch_size, max_time, cell.output_size].

      If time_major == True, this will be a Tensor shaped: [max_time, batch_size, cell.output_size].

      Note, if cell.output_size is a (possibly nested) tuple of integers or TensorShape objects, then outputs will be a tuple having the same structure as cell.output_size, containing Tensors having shapes corresponding to the shape data in cell.output_size.

    • state: The final state. If cell.state_size is an int, this will be shaped [batch_size, cell.state_size]. If it is a TensorShape, this will be shaped [batch_size] + cell.state_size. If it is a (possibly nested) tuple of ints or TensorShape, this will be a tuple having the corresponding shapes. If cells are LSTMCells state will be a tuple containing a LSTMStateTuple for each cell.
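
    The shape rules quoted above can be sketched as two small helper functions. This is only an illustration in plain Python: output_shape and state_shape are hypothetical helpers, not part of TensorFlow. In this sketch an int stands for a plain size, a list stands in for a TensorShape, and a tuple stands for a (possibly nested) structure such as the per-layer LSTMStateTuple of a MultiRNNCell.

```python
def output_shape(batch_size, max_time, output_size, time_major=False):
    """Expected shape(s) of `outputs` for a given cell.output_size."""
    if isinstance(output_size, tuple):
        # nested structure: outputs mirrors the structure of output_size
        return tuple(
            output_shape(batch_size, max_time, size, time_major)
            for size in output_size
        )
    tail = [output_size] if isinstance(output_size, int) else list(output_size)
    lead = [max_time, batch_size] if time_major else [batch_size, max_time]
    return lead + tail


def state_shape(batch_size, state_size):
    """Expected shape(s) of the final `state` for a given cell.state_size."""
    if isinstance(state_size, tuple):
        # nested structure: state mirrors the structure of state_size
        return tuple(state_shape(batch_size, size) for size in state_size)
    tail = [state_size] if isinstance(state_size, int) else list(state_size)
    return [batch_size] + tail


# Two stacked LSTM cells with 128 and 256 units: each layer's state is a (c, h) pair.
two_layer_lstm_state = ((128, 128), (256, 256))

print(output_shape(32, 10, 256))
print(output_shape(32, 10, 256, time_major=True))
print(state_shape(32, two_layer_lstm_state))
```

    Note that the state never has a time dimension: it describes only the final step, which is why time_major affects outputs but not state.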

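    outputs is the output tensor, and state is the final state, that is, the state at the last time step. For a BasicRNNCell the state is a single tensor; for an LSTMCell it is an LSTMStateTuple holding both c and h.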

    outputs contains the output H at every time step;

    state (for an LSTM cell) contains the output H and the cell state C of the last time step.
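
    The relationship between H and C can be checked with a small hand-written LSTM loop. This is a hedged sketch in pure Python, not the TensorFlow implementation: lstm_unroll, affine, make_weights and the params dictionary are hypothetical, the gate equations are the standard LSTM ones without peepholes, and the forget-gate bias of 1.0 is an assumption chosen to resemble the LSTMCell default. The returned state is ordered (c, h), the same order as LSTMStateTuple.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def affine(x, h, weights, bias):
    """Compute concat(x, h) @ weights + bias; weights is [input_dim + hidden][hidden]."""
    xh = x + h  # list concatenation
    return [
        sum(xh[i] * weights[i][j] for i in range(len(xh))) + bias[j]
        for j in range(len(bias))
    ]


def lstm_unroll(sequence, params):
    """Run a toy LSTM over one sequence [time][input_dim].

    Returns (outputs, (c, h)): outputs holds h for every time step,
    and the state pair holds c and h of the last time step only.
    """
    hidden = len(params["b_i"])
    c = [0.0] * hidden
    h = [0.0] * hidden
    outputs = []
    for x in sequence:
        i = [sigmoid(z) for z in affine(x, h, params["w_i"], params["b_i"])]    # input gate
        f = [sigmoid(z) for z in affine(x, h, params["w_f"], params["b_f"])]    # forget gate
        o = [sigmoid(z) for z in affine(x, h, params["w_o"], params["b_o"])]    # output gate
        g = [math.tanh(z) for z in affine(x, h, params["w_g"], params["b_g"])]  # candidate
        c = [f[k] * c[k] + i[k] * g[k] for k in range(hidden)]
        h = [o[k] * math.tanh(c[k]) for k in range(hidden)]
        outputs.append(list(h))
    return outputs, (c, h)


def make_weights(rows, cols, scale):
    """Small deterministic weights so the example needs no random seed."""
    return [[scale * ((r + col) % 3 - 1) for col in range(cols)] for r in range(rows)]


input_dim, hidden = 2, 2
params = {
    "w_i": make_weights(input_dim + hidden, hidden, 0.1), "b_i": [0.0] * hidden,
    "w_f": make_weights(input_dim + hidden, hidden, 0.2), "b_f": [1.0] * hidden,
    "w_o": make_weights(input_dim + hidden, hidden, 0.3), "b_o": [0.0] * hidden,
    "w_g": make_weights(input_dim + hidden, hidden, 0.4), "b_g": [0.0] * hidden,
}
sequence = [[1.0, 0.5], [0.2, -0.3], [-0.5, 0.7]]

outputs, (final_c, final_h) = lstm_unroll(sequence, params)
print(len(outputs))             # one H per time step
print(final_h == outputs[-1])   # the H in state is the last row of outputs
print(final_c != final_h)       # C is a separate vector that outputs never exposes
```

    In other words, the H half of the state is redundant with the last row of outputs, while C is available only through state.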

  • Original article: https://www.cnblogs.com/liangzp/p/10130242.html