zoukankan      html  css  js  c++  java
  • [PyTorch] rnn,lstm,gru中输入输出维度

    本文中的RNN泛指LSTM,GRU等等
    CNN中和RNNbatchSize的默认位置是不同的。

    • CNN中:batchsize的位置是position 0.
    • RNN中:batchsize的位置是position 1.

    在RNN中输入数据格式:

    对于最简单的RNN,我们可以使用两种方式来调用,torch.nn.RNNCell(),它只接受序列中的单步输入,必须显式的传入隐藏状态torch.nn.RNN()可以接受一个序列的输入,默认会传入一个全0的隐藏状态,也可以自己申明隐藏状态传入。

    1. 输入大小是三维tensor[seq_len,batch_size,input_dim]
    • input_dim是输入的维度,比如是128
    • batch_size是一次往RNN输入句子的数目,比如是5
    • seq_len是一个句子的最大长度,比如15
      所以千万注意,RNN输入的是序列,一次把批次的所有句子都输入了,得到的ouptuthidden都是这个批次的所有的输出和隐藏状态,维度也是三维。
      **可以理解为现在一共有batch_size个独立的RNN组件,RNN的输入维度是input_dim,总共输入seq_len个时间步,则每个时间步输入到这个整个RNN模块的维度是[batch_size,input_dim]
    # 构造RNN网络,x的维度5,隐层的维度10,网络的层数2
    rnn_seq = nn.RNN(5, 10,2)  
    # 构造一个输入序列,句长为 6,batch 是 3, 每个单词使用长度是 5的向量表示
    x = torch.randn(6, 3, 5)
    #out,ht = rnn_seq(x,h0) 
    out,ht = rnn_seq(x) #h0可以指定或者不指定
    

    问题1:这里outhtsize是多少呢?
    回答out:6 * 3 * 10, ht: 2 * 3 * 10,out的输出维度[seq_len,batch_size,output_dim],ht的维度[num_layers * num_directions, batch, hidden_size],如果是单向单层的RNN那么一个句子只有一个hidden
    问题2out[-1]ht[-1]是否相等?
    回答:相等,隐藏单元就是输出的最后一个单元,可以想象,每个的输出其实就是那个时间步的隐藏单元

    1. RNN的其他参数
    RNN(input_dim ,hidden_dim ,num_layers ,…)
    – input_dim 表示输入的特征维度
    – hidden_dim 表示输出的特征维度,如果没有特殊变化,相当于out
    – num_layers 表示网络的层数
    – nonlinearity 表示选用的非线性激活函数,默认是 ‘tanh’
    – bias 表示是否使用偏置,默认使用
    – batch_first 表示输入数据的形式,默认是 False,就是这样形式,(seq, batch, feature),也就是将序列长度放在第一位,batch 放在第二位
    – dropout 表示是否在输出层应用 dropout
    – bidirectional 表示是否使用双向的 rnn,默认是 False
    
     
    向RNN中输入的tensor的形状

    LSTM的输出多了一个memory单元

    # 输入维度 50,隐层100维,两层
    lstm_seq = nn.LSTM(50, 100, num_layers=2)
    # 输入序列seq= 10,batch =3,输入维度=50
    lstm_input = torch.randn(10, 3, 50)
    out, (h, c) = lstm_seq(lstm_input) # 使用默认的全 0 隐藏状态
    

    问题1out(h,c)的size各是多少?
    回答out:(10 * 3 * 100),(h,c):都是(2 * 3 * 100)
    问题2out[-1,:,:]h[-1,:,:]相等吗?
    回答: 相等

    GRU比较像传统的RNN

    gru_seq = nn.GRU(10, 20,2) # x_dim,h_dim,layer_num
    gru_input = torch.randn(3, 32, 10) # seq,batch,x_dim
    out, h = gru_seq(gru_input)


    作者:VanJordan
    链接:https://www.jianshu.com/p/b942e65cb0a3
    来源:简书
    简书著作权归作者所有,任何形式的转载都请联系作者获得授权并注明出处。
  • 相关阅读:
    小程序前端直传阿里云oss的一些记录
    小程序的两种分页做法(后端返回分页及总页数字段与否)
    小程序模糊搜索(词汇联想)
    小程序自定义组件的两种方式
    js对数据的一些处理方法(待完善)
    小程序关于登录授权回跳页面的两个问题记录
    小程序登录的一些简单步骤
    关于js的方括号[]属性赋值的一些记录
    js状态转化的简单写法
    微信企业号开发node版
  • 原文地址:https://www.cnblogs.com/jfdwd/p/11069096.html
Copyright © 2011-2022 走看看