zoukankan      html  css  js  c++  java
  • RNN神经网络层的输出格式和形状

    tf.keras.layers.RNN(
        cell, return_sequences=False, return_state=False, go_backwards=False,
        stateful=False, unroll=False, time_major=False, **kwargs
    )
    

    RNN的输出非常tricky,官方文档要细读。
    If return_state: a list of tensors. The first tensor is the output. The remaining tensors are the last states, each with shape [batch_size, state_size], where state_size could be a high dimension tensor shape.
    If return_sequences: N-D tensor with shape [batch_size, timesteps, output_size], where output_size could be a high dimension tensor shape, or [timesteps, batch_size, output_size] when time_major is True.
    Else, N-D tensor with shape [batch_size, output_size], where output_size could be a high dimension tensor shape.

    如果return_state=Truereturn_sequences=True,输出的一个list,第一个元素是output,shape [batch_size, timesteps, output_size],第二个元素是latest states,注意是latest,最后时刻的states,而不是每个时刻的states,所以shape [batch_size, state_size]

    out = keras.layers.RNN(LNSimpleRNNCell(20), return_sequences=True, return_state=True,
                         input_shape=[None, 1])(X_train)
    
    X_train.shape # (7000, 50, 1)
    out[0].shape # ([7000, 50, 20])
    out[1].shape # ([7000, 20])
    
  • 相关阅读:
    二分查找
    50道经典的JAVA编程题(46-50)
    50道经典的JAVA编程题(41-45)
    50道经典的JAVA编程题(36-40)
    50道经典的JAVA编程题(31-35)
    今天考试的JAVA编程题
    50道经典的JAVA编程题(26-30)
    50道经典的JAVA编程题(21-25)
    50道经典的JAVA编程题 (16-20)
    50道经典的JAVA编程题(目录)
  • 原文地址:https://www.cnblogs.com/yaos/p/14014148.html
Copyright © 2011-2022 走看看