  • How to feed variable-length sequences into an LSTM in PyTorch

    What to do when sequence lengths are not fixed while training an RNN in PyTorch?

    How to feed variable-length sequences into an LSTM in PyTorch

    The two articles above are well written and clearly explain the three functions needed to train an LSTM on variable-length sequences. However, neither provides complete training code, and neither covers the case where the sequences come with labels, so this post gives a complete, runnable training example with labels.
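
    Before the full example, here is a minimal sketch of how the three functions fit together (the sequence lengths and layer sizes below are arbitrary and only meant to show the data flow):

    import torch
    import torch.nn.utils.rnn as rnn_utils
    
    # three toy sequences of different lengths, one feature per timestep
    seqs = [torch.ones(5, 1), 2 * torch.ones(3, 1), 3 * torch.ones(2, 1)]
    lengths = [len(s) for s in seqs]          # [5, 3, 2], already sorted in descending order
    
    # 1) pad_sequence: list of (len_i, 1) tensors -> one (batch, max_len, 1) tensor
    padded = rnn_utils.pad_sequence(seqs, batch_first=True, padding_value=0.0)
    
    # 2) pack_padded_sequence: record the true length of every sequence so the LSTM skips the padding
    packed = rnn_utils.pack_padded_sequence(padded, lengths, batch_first=True)
    
    lstm = torch.nn.LSTM(input_size=1, hidden_size=4, batch_first=True)
    out_packed, _ = lstm(packed)              # out_packed.data has shape (5+3+2, 4)
    
    # 3) pad_packed_sequence: back to a padded (batch, max_len, 4) tensor plus the lengths
    out_padded, out_lens = rnn_utils.pad_packed_sequence(out_packed, batch_first=True)
    print(out_padded.shape, out_lens)         # torch.Size([3, 5, 4]) tensor([5, 3, 2])

    The complete training example with labels: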

    import torch
    from torch import nn
    import torch.nn.utils.rnn as rnn_utils
    from torch.utils.data import DataLoader
    import torch.utils.data as data_
    
    
    class MyData(data_.Dataset):
        def __init__(self, data, label):
            self.data = data
            self.label = label
    
        def __len__(self):
            return len(self.data)
    
        def __getitem__(self, idx):
            tuple_ = (self.data[idx], self.label[idx])
            return tuple_
    
    
    def collate_fn(data_tuple):   # data_tuple is a list of batchsize (data, label) tuples
        data_tuple.sort(key=lambda x: len(x[0]), reverse=True)
        data = [sq[0] for sq in data_tuple]
        label = [sq[1] for sq in data_tuple]
        data_length = [len(sq) for sq in data]
        data = rnn_utils.pad_sequence(data, batch_first=True, padding_value=0.0)     # pad with zeros so all sequences in the batch have the same length
        label = rnn_utils.pad_sequence(label, batch_first=True, padding_value=0.0)   # only needed to turn the list of label tensors into one tensor
        return data.unsqueeze(-1), label, data_length
    
    
    if __name__ == '__main__':
    
        EPOCH = 2
        batchsize = 3
        hiddensize = 4
        num_layers = 2
        learning_rate = 0.001
    
        # training data: sequences of different lengths
        train_x = [torch.FloatTensor([1, 1, 1, 1, 1, 1, 1]),
                   torch.FloatTensor([2, 2, 2, 2, 2, 2]),
                   torch.FloatTensor([3, 3, 3, 3, 3]),
                   torch.FloatTensor([4, 4, 4, 4]),
                   torch.FloatTensor([5, 5, 5]),
                   torch.FloatTensor([6, 6]),
                   torch.FloatTensor([7])]
        # labels: one target vector per timestep
        train_y = [torch.rand(7, hiddensize),
                   torch.rand(6, hiddensize),
                   torch.rand(5, hiddensize),
                   torch.rand(4, hiddensize),
                   torch.rand(3, hiddensize),
                   torch.rand(2, hiddensize),
                   torch.rand(1, hiddensize)]
    
        dataset = MyData(train_x, train_y)
        data_loader = DataLoader(dataset, batch_size=batchsize, shuffle=True, collate_fn=collate_fn)
        net = nn.LSTM(input_size=1, hidden_size=hiddensize, num_layers=num_layers, batch_first=True)
        criteria = nn.MSELoss()
        optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
    
        # Training method 1: unpack the LSTM output with pad_packed_sequence and compute the MSE on the padded tensors
        # (padded positions are zero in both tensors, but they still count toward the mean)
        for epoch in range(EPOCH):
            for batch_id, (batch_x, batch_y, batch_x_len) in enumerate(data_loader):
                batch_x_pack = rnn_utils.pack_padded_sequence(batch_x, batch_x_len, batch_first=True)
                out, _ = net(batch_x_pack)   # out.data has shape (sum of all sequence lengths, hiddensize)
                out_pad, out_len = rnn_utils.pad_packed_sequence(out, batch_first=True)
                loss = criteria(out_pad, batch_y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                print('epoch:{:2d}, batch_id:{:2d}, loss:{:6.4f}'.format(epoch, batch_id, loss.item()))
    
        # Training method 2: pack the labels with the same lengths and compute the MSE directly on the packed .data,
        # so only the real (non-padded) timesteps enter the loss
        for epoch in range(EPOCH):
            for batch_id, (batch_x, batch_y, batch_x_len) in enumerate(data_loader):
                batch_x_pack = rnn_utils.pack_padded_sequence(batch_x, batch_x_len, batch_first=True)
                batch_y_pack = rnn_utils.pack_padded_sequence(batch_y, batch_x_len, batch_first=True)
                out, _ = net(batch_x_pack)   # out.data has shape (sum of all sequence lengths, hiddensize)
                loss = criteria(out.data, batch_y_pack.data)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                print('epoch:{:2d}, batch_id:{:2d}, loss:{:6.4f}'.format(epoch, batch_id, loss.item()))
    
        print('Training done!')

    Running output:

     

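    Besides the two methods above, the padded output from method 1 can also be combined with an explicit mask, so that only the real timesteps contribute to the MSE and short sequences do not dilute the loss. A minimal sketch (the helper name masked_mse is illustrative; with the collate_fn above it would replace criteria(out_pad, batch_y) in method 1):

    import torch
    
    def masked_mse(out_pad, target_pad, lengths):
        # out_pad, target_pad: (batch, max_len, hidden); lengths: list of true sequence lengths
        max_len = out_pad.size(1)
        # mask[b, t] is True for real timesteps and False for padding
        mask = torch.arange(max_len).unsqueeze(0) < torch.tensor(lengths).unsqueeze(1)
        mask = mask.unsqueeze(-1).to(out_pad.dtype)              # (batch, max_len, 1)
        se = (out_pad - target_pad) ** 2 * mask                  # zero out padded positions
        return se.sum() / (mask.sum() * out_pad.size(2))         # average over real elements only
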
  • Original article: https://www.cnblogs.com/picassooo/p/13577527.html