zoukankan      html  css  js  c++  java
  • pytorch nn.LSTM()参数详解

    输入数据格式:
    input(seq_len, batch, input_size)
    h0(num_layers * num_directions, batch, hidden_size)
    c0(num_layers * num_directions, batch, hidden_size)

    输出数据格式:
    output(seq_len, batch, hidden_size * num_directions)
    hn(num_layers * num_directions, batch, hidden_size)
    cn(num_layers * num_directions, batch, hidden_size)

    import torch
    import torch.nn as nn
    from torch.autograd import Variable

    #构建网络模型---输入矩阵特征数input_size、输出矩阵特征数hidden_size、层数num_layers
    inputs = torch.randn(5,3,10) ->(seq_len,batch_size,input_size)
    rnn = nn.LSTM(10,20,2) -> (input_size,hidden_size,num_layers)
    h0 = torch.randn(2,3,20) ->(num_layers* 1,batch_size,hidden_size)
    c0 = torch.randn(2,3,20) ->(num_layers*1,batch_size,hidden_size)
    num_directions=1 因为是单向LSTM
    '''
    Outputs: output, (h_n, c_n)
    '''
    output,(hn,cn) = rnn(inputs,(h0,c0))
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    batch_first: 输入输出的第一维是否为 batch_size,默认值 False。因为 Torch 中,人们习惯使用Torch中带有的dataset,dataloader向神经网络模型连续输入数据,这里面就有一个 batch_size 的参数,表示一次输入多少个数据。 在 LSTM 模型中,输入数据必须是一批数据,为了区分LSTM中的批量数据和dataloader中的批量数据是否相同意义,LSTM 模型就通过这个参数的设定来区分。 如果是相同意义的,就设置为True,如果不同意义的,设置为False。 torch.LSTM 中 batch_size 维度默认是放在第二维度,故此参数设置可以将 batch_size 放在第一维度。如:input 默认是(4,1,5),中间的 1 是 batch_size,指定batch_first=True后就是(1,4,5)。所以,如果你的输入数据是二维数据的话,就应该将 batch_first 设置为True;

    inputs = torch.randn(5,3,10) :seq_len=5,bitch_size=3,input_size=10
    我的理解:有3个句子,每个句子5个单词,每个单词用10维的向量表示;而句子的长度是不一样的,所以seq_len可长可短,这也是LSTM可以解决长短序列的特殊之处。只有seq_len这一参数是可变的。
    关于hn和cn一些参数的详解看这里
    而在遇到文本长度不一致的情况下,将数据输入到模型前的特征工程会将同一个batch内的文本进行padding使其长度对齐。但是对齐的数据在单向LSTM甚至双向LSTM的时候有一个问题,LSTM会处理很多无意义的填充字符,这样会对模型有一定的偏差,这时候就需要用到函数torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()
    详情解释看这里

    BiLSTM
    BILSTM是双向LSTM;将前向的LSTM与后向的LSTM结合成LSTM。视图举例如下:


    ​​​​​​​​​​​​LSTM结构推导:


    更详细公式推导https://blog.csdn.net/songhk0209/article/details/71134698

    GRU公式推导:(网上的图看着有点费劲,就自己画了个数据流图)


    ---------------------
    作者:向阳争渡
    来源:CSDN
    原文:https://blog.csdn.net/yangyang_yangqi/article/details/84585998
    版权声明:本文为博主原创文章,转载请附上博文链接!

  • 相关阅读:
    淡入淡出js
    Comparable和Comparator的区别
    mybatis的动态sql详解
    mybatis动态sql之foreach
    mybatis的动态sql中collection与assoction
    Mybatis中#与$区别
    转JSONObject put,accumulate,element的区别
    Spring配置,JDBC数据源及事务
    销毁session
    IIS express 7.5 设置默认文档
  • 原文地址:https://www.cnblogs.com/jfdwd/p/11184846.html
Copyright © 2011-2022 走看看