zoukankan      html  css  js  c++  java
  • pytorch中的pack_padded_sequence和pad_packed_sequence用法

    pack_padded_sequence是将句子按照batch优先的原则记录每个句子的词,变化为不定长tensor,方便计算损失函数。

    pad_packed_sequence是将pack_padded_sequence生成的结构转化为原先的结构,定长的tensor。

    其中test.txt的内容

    As they sat in a nice coffee shop, 
    he was too nervous to say anything and she felt uncomfortable. 
    Suddenly, he asked the waiter, 
    "Could you please give me some salt? I'd like to put it in my coffee."

    具体参见如下代码

    import torch
    import torch.nn as nn
    from torch.autograd import Variable
    import numpy as np
    import wordfreq
    
    vocab = {}
    token_id = 1
    lengths = []
    
    #读取文件,生成词典
    with open('test.txt', 'r') as f:
        lines=f.readlines()
        for line in lines:
            tokens = wordfreq.tokenize(line.strip(), 'en')
            lengths.append(len(tokens))
            #将每个词加入到vocab中,并同时保存对应的index
            for word in tokens:
                if word not in vocab:
                    vocab[word] = token_id
                    token_id += 1
    
    x = np.zeros((len(lengths), max(lengths)))
    l_no = 0
    #将词转化为数字
    with open('test.txt', 'r') as f:
        lines = f.readlines()
        for line in lines:
            tokens = wordfreq.tokenize(line.strip(), 'en')
            for i in range(len(tokens)):
                x[l_no, i] = vocab[tokens[i]]
            l_no += 1
    
    x=torch.Tensor(x)
    x = Variable(x)
    print(x)
    '''
    tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  0.,  0.,  0.,  0.,  0.,  0.],
            [ 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19.,  0.,  0.,  0.],
            [20.,  9., 21., 22., 23.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
            [24., 25., 26., 27., 28., 29., 30., 31., 32., 13., 33., 34.,  4.,  7.]])
    '''
    lengths = torch.Tensor(lengths)
    print(lengths)#tensor([ 8., 11.,  5., 14.])
    
    _, idx_sort = torch.sort(torch.Tensor(lengths), dim=0, descending=True)
    print(_) #tensor([14., 11.,  8.,  5.])
    print(idx_sort)#tensor([3, 1, 0, 2])
    
    lengths = list(lengths[idx_sort])#按下标取元素 [tensor(14.), tensor(11.), tensor(8.), tensor(5.)]
    t = x.index_select(0, idx_sort)#按下标取元素
    print(t)
    '''
    tensor([[24., 25., 26., 27., 28., 29., 30., 31., 32., 13., 33., 34.,  4.,  7.],
            [ 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19.,  0.,  0.,  0.],
            [ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  0.,  0.,  0.,  0.,  0.,  0.],
            [20.,  9., 21., 22., 23.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]])
    '''
    x_packed = nn.utils.rnn.pack_padded_sequence(input=t, lengths=lengths, batch_first=True)
    print(x_packed)
    '''
    PackedSequence(data=tensor([24.,  9.,  1., 20., 25., 10.,  2.,  9., 26., 11.,  3., 21., 27., 12.,
             4., 22., 28., 13.,  5., 23., 29., 14.,  6., 30., 15.,  7., 31., 16.,
             8., 32., 17., 13., 18., 33., 19., 34.,  4.,  7.]), batch_sizes=tensor([4, 4, 4, 4, 4, 3, 3, 3, 2, 2, 2, 1, 1, 1]))
    '''
    
    
    x_padded = nn.utils.rnn.pad_packed_sequence(x_packed, batch_first=True)#x_padded是tuple
    print(x_padded)
    '''
    (tensor([[24., 25., 26., 27., 28., 29., 30., 31., 32., 13., 33., 34.,  4.,  7.],
            [ 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19.,  0.,  0.,  0.],
            [ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  0.,  0.,  0.,  0.,  0.,  0.],
            [20.,  9., 21., 22., 23.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]]), tensor([14, 11,  8,  5]))
    '''
    #还原tensor
    _, idx_unsort = torch.sort(idx_sort)
    output = x_padded[0].index_select(0, idx_unsort)
    print(output)
    '''
    tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  0.,  0.,  0.,  0.,  0.,  0.],
            [ 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19.,  0.,  0.,  0.],
            [20.,  9., 21., 22., 23.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
            [24., 25., 26., 27., 28., 29., 30., 31., 32., 13., 33., 34.,  4.,  7.]])
    '''
  • 相关阅读:
    GeoServer源码之Dispatcher
    Geoserver开发之OWS是什么?
    GeoServer数据工作空间:怎么设置?
    Java注解:不使用注解的话,也能实现初始化bean吗?
    JSP编译错误无法显示:PWC6033: Unable to compile class for JSP
    Spring Bean初始化失败
    【每日一具17】CAD迷你画图/最新2020R9
    用python爬虫简单网站却有 “多重思路”--猫眼电影
    【每日一具17】CAD迷你画图/最新2020R9
    Python教你如何对 Excel(xlxs文件) 表的读写和处理
  • 原文地址:https://www.cnblogs.com/AntonioSu/p/12015141.html
Copyright © 2011-2022 走看看