zoukankan      html  css  js  c++  java
  • nn.LSTM输入、输出、参数及pad

    1.LSTM的三个输出output, hidden, cell,分别表示什么意思?

    https://blog.csdn.net/wangwangstone/article/details/90296461 这里最后的代码中能搞明白。

    import torch
    import torch.nn as nn             # 神经网络模块
    rnn = nn.LSTM(10, 20, 2) 
    # 输入数据x的向量维数10, 设定lstm隐藏层的特征维度20, 此model用2个lstm层。如果是1,可以省略,默认为1)
    input = torch.randn(5, 3, 10)
    # 输入的input为,序列长度seq_len=5, 每次取的minibatch大小,batch_size=3, 
    # 数据向量维数=10(仍然为x的维度)。每次运行时取3个含有5个字的句子(且句子中每个字的维度为10进行运行)  # 初始化的隐藏元和记忆元,通常它们的维度是一样的 # 2个LSTM层,batch_size=3, 隐藏层的特征维度20 h0 = torch.randn(2, 3, 20) c0 = torch.randn(2, 3, 20) # 这里有2层lstm,output是最后一层lstm的每个词向量对应隐藏层的输出,其与层数无关,只与序列长度相关 # hn,cn是所有层最后一个隐藏元和记忆元的输出,和层数、隐层大小有关。 output, (hn, cn) = rnn(input, (h0, c0)) ##模型的三个输入与三个输出。三个输入与输出的理解见上三输入,三输出 print(output.size(),hn.size(),cn.size()) #输出:torch.Size([5, 3, 20]) torch.Size([2, 3, 20]) torch.Size([2, 3, 20])

     输入数据格式: (三个输入)

    input(seq_len, batch, input_size)

    h_0(num_layers * num_directions, batch, hidden_size)

    c_0(num_layers * num_directions, batch, hidden_size)

    输出数据格式:

    output(seq_len, batch, hidden_size * num_directions)

    h_n(num_layers * num_directions, batch, hidden_size)

    c_n(num_layers * num_directions, batch, hidden_size)

    设置batch_first=True:

    import torch
    import torch.nn as nn 
    rnn = nn.LSTM(10, 20, 2,batch_first=True)
    input = torch.randn(5, 3, 10)
    h0 = torch.randn(2, 5, 20)
    c0 = torch.randn(2, 5, 20)
    output, (hn, cn) = rnn(input, (h0, c0))
    print(output.size(),hn.size(),cn.size())
    
    torch.Size([5, 3, 20]) torch.Size([2, 5, 20]) torch.Size([2, 5, 20])
    #可以看到是shape和输入保持一致的
    #如果开始设置了批次在前,那么输出时也会批次在前。

    2.pack_padded_sequence /pad_packed_sequence

    https://zhuanlan.zhihu.com/p/34418001?edition=yidianzixun&utm_source=yidianzixun&yidian_docid=0IVwLf60

    https://www.cnblogs.com/lindaxin/p/8052043.html

    这个还蛮有意思的,通过这个代码我明白了:

    import torch
    import torch.nn as nn
    from torch.autograd import Variable
    from torch.nn import utils as nn_utils
    batch_size = 2
    max_length = 3
    hidden_size = 2
    n_layers =1
     
    tensor_in = torch.FloatTensor([[1, 2, 3], [1, 0, 0]]).resize_(2,3,1)
    tensor_in = Variable( tensor_in ) #[batch, seq, feature], [2, 3, 1]
    seq_lengths = [3,1] # list of integers holding information about the batch size at each sequence step
     
    # pack it
    pack = nn_utils.rnn.pack_padded_sequence(tensor_in, seq_lengths, batch_first=True)
     
    # initialize
    rnn = nn.RNN(1, hidden_size, n_layers, batch_first=True)
    h0 = Variable(torch.randn(n_layers, batch_size, hidden_size))
     
    #forward
    out, _ = rnn(pack, h0)
     
    # unpack
    unpacked = nn_utils.rnn.pad_packed_sequence(out)
    print('111',unpacked)

     输出:

    111 (tensor([[[ 0.0217, -0.4686],
             [-0.0132, -0.4441]],
    
            [[-0.2905, -0.7984],
             [ 0.0000,  0.0000]],
    
            [[-0.6398, -0.9216],
             [ 0.0000,  0.0000]]], grad_fn=<CopySlices>), tensor([3, 1]))
    
    #中间结果:
    >>> tensor_in
    tensor([[[1.],
             [2.],
             [3.]],
    
            [[1.],
             [0.],
             [0.]]])
    #输入的是已经添加了batchsize的,是一个三维的
    
    >>> pack
    PackedSequence(data=tensor([[1.],
            [1.],
            [2.],
            [3.]]), batch_sizes=tensor([2, 1, 1]))
    #经过压缩之后就可以发现很有意思,它是按照batch压缩的,压缩成一个长序列,然后还有一个batch_size的参数来控制读取。
    
    >>> h0
    tensor([[[ 0.2472,  0.2233],
             [-0.0616,  0.2127]]])
    #初始化的隐状态
    
    >>> out
    PackedSequence(data=tensor([[ 0.0217, -0.4686],
            [-0.0132, -0.4441],
            [-0.2905, -0.7984],
            [-0.6398, -0.9216]], grad_fn=<CatBackward>), batch_sizes=tensor([2, 1, 1]))
    #这个是未经过pad之前的,可以发现是和pack的结构类似,
    #为什么变成了二维的呢,是因为hidden_size是二维的

    总之是为了lstm的特殊输入,需要先pack,pack之后lstm的output再pad输出。

  • 相关阅读:
    Python 自动化测试实战训练营,由浅入深,从小白到测试高手!
    接口测试 Mock 实战(二) | 结合 jq 完成批量化的手工 Mock
    从文科生转行测试,再到大厂测试开发工程师,我是如何做到的?
    严正声明|严厉打击盗版侵权、非法销售「霍格沃兹测试学院」课程的违法行为
    「金羽毛」有奖征文 | 记录测试开发技术进阶之路的点滴
    测试工程师职业发展漫谈
    Workshop 深圳站|实战+源码架构剖析带你揭开Appium的神秘面纱
    那些难改的 Bug,最后都怎样了?
    BAT大厂都在用的Docker。学会这三招,面试、工作轻松hold住
    2021 开年学习送福利,助力测试进阶提升!
  • 原文地址:https://www.cnblogs.com/BlueBlueSea/p/13723560.html
Copyright © 2011-2022 走看看