zoukankan      html  css  js  c++  java
  • pytorch 对变长序列的处理

    一开始写这篇随笔的时候还没有了解到 Dateloader有一个 collate_fn 的参数,通过定义一个collate_fn 函数,其实很多batch补齐到当前batch最长的操作可以放在collate_fn 里面去,这样代码在训练和模型中就可以更加简洁。有时间再整理一下这个吧。

    _________________________________________

    使用的主要部分包括:Dateset、 Dateloader、MSELoss、PackedSequence、pack_padded_sequence、pad_packed_sequence

    模型包含LSTM模块。

    参考了下面两篇博文,总结了一下。对PackedSequence相关的理解可以先看这两篇。本文主要是把这些应用从数据准备到loss计算都串起来大致提供了一下代码思路,权当给自己的提醒备份吧。或者看完下面两篇,但是不知道具体怎么操作的朋友们一个参考。

    http://www.cnblogs.com/lindaxin/p/8052043.html#commentform

    https://blog.csdn.net/lssc4205/article/details/79474735

    使用Dateset构建数据集的时候,在__getitem__函数中把所有数据先补齐到 全局最长序列的长度。

    def __getitem__(self, index):
        '''
    get original data
    此处省略获取原始数据的代码
    input_data,output_data
    数据shape是  seq_length * feature_dim
        '''
    # 当前seq_length小于所有数据中的最长数据长度,则补0到同一长度。
        ori_length = input_data.shape[0]
        if ori_length < self.max_len:
            npi = np.zeros(self.input_feature_dim, dtype=np.float32)
            npi = np.tile(npi, (self.max_len - ori_length,1))
            input_data = np.row_stack((input_data, npi))
            npo = np.zeros(self.output_feature_dim, dtype=np.float32)
            npo = np.tile(npo, (self.max_len - ori_length,1))
            output_data = np.row_stack((output_data, npo))
        return input_data, output_data, ori_length, input_data_path

     在模型中,forward的实现中,需要在LSTM之前使用pack_padded_sequence、在LSTM之后使用pad_packed_sequence,中间还涉及到顺序的还原之类的操作。

    def forward(self, input_x, length_list, hidden=None):
        if hidden is None:
            # 这里没用 配置中的batch_size,而是直接在input_x中取batch_size是为了防止last_batch的batch_size不是配置中的那个,引发bug
            h_0 = input_x.data.new(self.directional*self.layer_num, input_x.shape[0], self.hidden_dim).fill_(0).float()
            c_0 = input_x.data.new(self.directional*self.layer_num, input_x.shape[0], self.hidden_dim).fill_(0).float()
        else:
            h_0, c_0 = hidden
        '''
    省略模型其他部分,直接进去LSTM前后的操作
        '''
        _, idx_sort = torch.sort(length_list, dim=0, descending=True)
        _, idx_unsort = otrch.sort(idx_sort, dim=0)
    
        input_x = input_x.index_select(0, Variable(idx_sort))
        length_list = list(length_list[idx_sort])
        pack = nn_utils.rnn.pack_padded_sequence(input_x, length_list, batch_first=self.batch_first)
        output, hidden = self.BiLSTM(pack, (h0, c0))
        un_padded = nn_utils.rnn.pad_packed_sequence(output, batch_first=self.batch_first)
        un_padded = un_padded[0].index_select(0, Variable(idx_unsort))
    # 此时的un_padded已经完成了还原,并且补0完成,而且这时的补0到的序列长度是当前batch的最长长度,而不是Dateset中的全局最长长度!
    # 所以在main train函数中也要对label的seq做处理
    return un_padded

     main train中,要对label做相应的截断处理,因为模型返回的长度已经是补齐到当前batch的最长序列长度了,而dateset返回的label是补齐到全局最长序列长度。算loss的时候,MSELoss的reduce参数要设置成false,让loss函数返回一个loss矩阵,再构造一个01掩膜矩阵mask,矩阵相乘求和得到真的loss(达到填充0的位置不参与loss的目的)

    def train(**kwargs):
      train_data = my_dataset()
      train_dataloader = DataLoader(train_data, opt.batch_size, shuffle=True, num_workers=opt.num_workers)
      model = getattr(models, opt.model)(batchsize=opt.batch_size)
      criterion = torch.nn.MSELoss(reduce=False)
      lr = opt.lf
      optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=opt.weight_decay)
      for epoch in range(opt.start_epoch, opt.max_epoch):
        for ii, (data, label, length_list,_) in tqdm(enumerate(train_dataloader)):
          cur_batch_max_len = length_list.max()
          data = Variable(data)
          target = Variable(label)

          optimizer.zero_grad()
          score = model(data, length_list)
          loss_mat = criterion(score, target)
          list_int = list(length_list)
          mask_mat = Variable(t.ones(len(list_int),cur_batch_max_len,opt.output_feature_dim))
          num_element = 0
          for idx_sample in range(len(list_int)):
            num_element += list_int[idx_sample] * opt.output_feature_dim
            if list_int[idx_sample] != cur_batch_max_len:
              mask_mat[idx_sample, list[idx_sample]:] = 0.0

          loss = (loss_mat * mask_mat).sum() / num_element
          loss.backward()
          optimizer.step()


  • 相关阅读:
    terminal
    变量提升、函数提升
    ssh传输文件
    mocha测试框架
    npm-run 自动化
    webpack
    浅析babel
    构建工具gulp
    C++中TRACE宏及assert()函数的使用
    memcpy函数-C语言
  • 原文地址:https://www.cnblogs.com/chengebigdata/p/8993990.html
Copyright © 2011-2022 走看看