zoukankan      html  css  js  c++  java
  • [转] Torch中实现mini-batch RNN

    工作中需要把一个SGD的LSTM改造成mini-batch的LSTM, 两篇比较有用的博文,转载mark

    https://zhuanlan.zhihu.com/p/34418001

    http://www.cnblogs.com/lindaxin/p/8052043.html

    一、为什么RNN需要处理变长输入

    假设我们有情感分析的例子,对每句话进行一个感情级别的分类,主体流程大概是下图所示:

    思路比较简单,但是当我们进行batch个训练数据一起计算的时候,我们会遇到多个训练样例长度不同的情况,这样我们就会很自然的进行padding,将短句子padding为跟最长的句子一样。

    比如向下图这样:

    但是这会有一个问题,什么问题呢?比如上图,句子“Yes”只有一个单词,但是padding了5的pad符号,这样会导致LSTM对它的表示通过了非常多无用的字符,这样得到的句子表示就会有误差,更直观的如下图:

    那么我们正确的做法应该是怎么样呢?

    这就引出pytorch中RNN需要处理变长输入的需求了。在上面这个例子,我们想要得到的表示仅仅是LSTM过完单词"Yes"之后的表示,而不是通过了多个无用的“Pad”得到的表示:如下图:

    二、pytorch中RNN如何处理变长padding

    主要是用函数torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()来进行的,分别来看看这两个函数的用法。

    这里的pack,理解成压紧比较好。 将一个 填充过的变长序列 压紧。(填充时候,会有冗余,所以压紧一下)

    输入的形状可以是(T×B×* )。T是最长序列长度,B是batch size,*代表任意维度(可以是0)。如果batch_first=True的话,那么相应的 input size 就是 (B×T×*)。

    Variable中保存的序列,应该按序列长度的长短排序,长的在前,短的在后(特别注意需要进行排序)。即input[:,0]代表的是最长的序列,input[:, B-1]保存的是最短的序列。

    参数说明:

    input (Variable) – 变长序列 被填充后的 batch

    lengths (list[int]) – Variable 中 每个序列的长度。(知道了每个序列的长度,才能知道每个序列处理到多长停止

    batch_first (bool, optional) – 如果是True,input的形状应该是B*T*size。

    返回值:

    一个PackedSequence 对象。一个PackedSequence表示如下所示:

    具体代码如下:

    embed_input_x_packed = pack_padded_sequence(embed_input_x, sentence_lens, batch_first=True)
    encoder_outputs_packed, (h_last, c_last) = self.lstm(embed_input_x_packed)
    

    此时,返回的h_last和c_last就是剔除padding字符后的hidden state和cell state,都是Variable类型的。代表的意思如下(各个句子的表示,lstm只会作用到它实际长度的句子,而不是通过无用的padding字符,下图用红色的打钩来表示):

    但是返回的output是PackedSequence类型的,可以使用:

    encoder_outputs, _ = pad_packed_sequence(encoder_outputs_packed, batch_first=True)
    

    将encoderoutputs在转换为Variable类型,得到的_代表各个句子的长度。

    三、总结

    这样综上所述,RNN在处理类似变长的句子序列的时候,我们就可以配套使用torch.nn.utils.rnn.pack_padded_sequence()以及torch.nn.utils.rnn.pad_packed_sequence()来避免padding对句子表示的影响

    PackedSequence对象有一个很不错的特性,就是我们无需对序列解包(这一步操作非常慢)即可直接在PackedSequence数据变量上执行许多操作。特别是我们可以对令牌执行任何操作(即对令牌的顺序/上下文不敏感)。当然,我们也可以使用接受PackedSequence作为输入的任何一个pyTorch模块(pyTorch 0.2)。

    2、torch.nn.utils.rnn.pack_padded_sequence()

    这里的pack,理解成压紧比较好。 将一个 填充过的变长序列 压紧。(填充时候,会有冗余,所以压紧一下)

    输入的形状可以是(T×B×* )。T是最长序列长度,Bbatch size*代表任意维度(可以是0)。如果batch_first=True的话,那么相应的 input size 就是 (B×T×*)

    Variable中保存的序列,应该按序列长度的长短排序,长的在前,短的在后。即input[:,0]代表的是最长的序列,input[:, B-1]保存的是最短的序列。

    NOTE: 只要是维度大于等于2的input都可以作为这个函数的参数。你可以用它来打包labels,然后用RNN的输出和打包后的labels来计算loss。通过PackedSequence对象的.data属性可以获取 Variable

    参数说明:

    • input (Variable) – 变长序列 被填充后的 batch

    • lengths (list[int]) – Variable 中 每个序列的长度。

    • batch_first (bool, optional) – 如果是True,input的形状应该是B*T*size

    返回值:

    一个PackedSequence 对象。

    3、torch.nn.utils.rnn.pad_packed_sequence()

    填充packed_sequence

    上面提到的函数的功能是将一个填充后的变长序列压紧。 这个操作和pack_padded_sequence()是相反的。把压紧的序列再填充回来。

    返回的Varaible的值的size是 T×B×*T 是最长序列的长度,B 是 batch_size,如果 batch_first=True,那么返回值是B×T×*

    Batch中的元素将会以它们长度的逆序排列。

    参数说明:

    • sequence (PackedSequence) – 将要被填充的 batch

    • batch_first (bool, optional) – 如果为True,返回的数据的格式为 B×T×*

    返回值: 一个tuple,包含被填充后的序列,和batch中序列的长度列表。

    例子:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    import torch
    import torch.nn as nn
    from torch.autograd import Variable
    from torch.nn import utils as nn_utils
    batch_size = 2
    max_length = 3
    hidden_size = 2
    n_layers =1
     
    tensor_in = torch.FloatTensor([[1, 2, 3], [1, 0, 0]]).resize_(2,3,1)
    tensor_in = Variable( tensor_in ) #[batch, seq, feature], [2, 3, 1]
    seq_lengths = [3,1] # list of integers holding information about the batch size at each sequence step
     
    # pack it
    pack = nn_utils.rnn.pack_padded_sequence(tensor_in, seq_lengths, batch_first=True)
     
    # initialize
    rnn = nn.RNN(1, hidden_size, n_layers, batch_first=True)
    h0 = Variable(torch.randn(n_layers, batch_size, hidden_size))
     
    #forward
    out, _ = rnn(pack, h0)
     
    # unpack
    unpacked = nn_utils.rnn.pad_packed_sequence(out)
    print('111',unpacked)

     输出:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    111 (Variable containing:
    (0 ,.,.) =
      0.5406  0.3584
     -0.1403  0.0308
     
    (1 ,.,.) =
     -0.6855 -0.9307
      0.0000  0.0000
    [torch.FloatTensor of size 2x2x2]
    , [2, 1])

     


  • 相关阅读:
    JavaScript实现常用的排序算法
    jQuery学习之路(8)- 表单验证插件-Validation
    jQuery学习之路(7)- 用原生JavaScript实现jQuery的某些简单功能
    jQuery学习之路(6)- 简单的表格应用
    jQuery学习之路(5)- 简单的表单应用
    jQuery学习之路(4)- 动画
    JavaScript常见的五种数组去重的方式
    jQuery学习之路(3)- 事件
    jQuery学习之路(2)-DOM操作
    Docker使用非root用户
  • 原文地址:https://www.cnblogs.com/Arborday/p/9651105.html
Copyright © 2011-2022 走看看