zoukankan      html  css  js  c++  java
  • [PyTorch] rnn,lstm,gru中输入输出维度

    本文中的RNN泛指LSTM,GRU等等
    CNN中和RNNbatchSize的默认位置是不同的。

    • CNN中:batchsize的位置是position 0.
    • RNN中:batchsize的位置是position 1.

    在RNN中输入数据格式:

    对于最简单的RNN,我们可以使用两种方式来调用,torch.nn.RNNCell(),它只接受序列中的单步输入,必须显式的传入隐藏状态torch.nn.RNN()可以接受一个序列的输入,默认会传入一个全0的隐藏状态,也可以自己申明隐藏状态传入。

    1. 输入大小是三维tensor[seq_len,batch_size,input_dim]
    • input_dim是输入的维度,比如是128
    • batch_size是一次往RNN输入句子的数目,比如是5
    • seq_len是一个句子的最大长度,比如15
      所以千万注意,RNN输入的是序列,一次把批次的所有句子都输入了,得到的ouptuthidden都是这个批次的所有的输出和隐藏状态,维度也是三维。
      **可以理解为现在一共有batch_size个独立的RNN组件,RNN的输入维度是input_dim,总共输入seq_len个时间步,则每个时间步输入到这个整个RNN模块的维度是[batch_size,input_dim]
    # 构造RNN网络,x的维度5,隐层的维度10,网络的层数2
    rnn_seq = nn.RNN(5, 10,2)  
    # 构造一个输入序列,句长为 6,batch 是 3, 每个单词使用长度是 5的向量表示
    x = torch.randn(6, 3, 5)
    #out,ht = rnn_seq(x,h0) 
    out,ht = rnn_seq(x) #h0可以指定或者不指定
    

    问题1:这里outhtsize是多少呢?
    回答out:6 * 3 * 10, ht: 2 * 3 * 10,out的输出维度[seq_len,batch_size,output_dim],ht的维度[num_layers * num_directions, batch, hidden_size],如果是单向单层的RNN那么一个句子只有一个hidden
    问题2out[-1]ht[-1]是否相等?
    回答:相等,隐藏单元就是输出的最后一个单元,可以想象,每个的输出其实就是那个时间步的隐藏单元

    1. RNN的其他参数
    RNN(input_dim ,hidden_dim ,num_layers ,…)
    – input_dim 表示输入的特征维度
    – hidden_dim 表示输出的特征维度,如果没有特殊变化,相当于out
    – num_layers 表示网络的层数
    – nonlinearity 表示选用的非线性激活函数,默认是 ‘tanh’
    – bias 表示是否使用偏置,默认使用
    – batch_first 表示输入数据的形式,默认是 False,就是这样形式,(seq, batch, feature),也就是将序列长度放在第一位,batch 放在第二位
    – dropout 表示是否在输出层应用 dropout
    – bidirectional 表示是否使用双向的 rnn,默认是 False
    
     
    向RNN中输入的tensor的形状

    LSTM的输出多了一个memory单元

    # 输入维度 50,隐层100维,两层
    lstm_seq = nn.LSTM(50, 100, num_layers=2)
    # 输入序列seq= 10,batch =3,输入维度=50
    lstm_input = torch.randn(10, 3, 50)
    out, (h, c) = lstm_seq(lstm_input) # 使用默认的全 0 隐藏状态
    

    问题1out(h,c)的size各是多少?
    回答out:(10 * 3 * 100),(h,c):都是(2 * 3 * 100)
    问题2out[-1,:,:]h[-1,:,:]相等吗?
    回答: 相等

    GRU比较像传统的RNN

    gru_seq = nn.GRU(10, 20,2) # x_dim,h_dim,layer_num
    gru_input = torch.randn(3, 32, 10) # seq,batch,x_dim
    out, h = gru_seq(gru_input)


    作者:VanJordan
    链接:https://www.jianshu.com/p/b942e65cb0a3
    来源:简书
    简书著作权归作者所有,任何形式的转载都请联系作者获得授权并注明出处。
  • 相关阅读:
    利用相关的Aware接口
    java 值传递和引用传递。
    权限控制框架Spring Security 和Shiro 的总结
    优秀代码养成
    Servlet 基础知识
    leetcode 501. Find Mode in Binary Search Tree
    leetcode 530. Minimum Absolute Difference in BST
    leetcode 543. Diameter of Binary Tree
    leetcode 551. Student Attendance Record I
    leetcode 563. Binary Tree Tilt
  • 原文地址:https://www.cnblogs.com/jfdwd/p/11069096.html
Copyright © 2011-2022 走看看