zoukankan      html  css  js  c++  java
  • Social GAN代码要点记录

    近日在阅读Social GAN文献的实验代码,加深对模型的理解,发现源代码的工程化很强,也比较适合构建实验模型的学习,故细致阅读。下文是笔者阅读中一些要点总结,有关于pytorch,也有关于模型自身的。

    GPU -> CPU

    SGAN的实验代码在工程化方面考虑比较充分,考虑到了在CPU和GPU两种平台上模型的运行。原生平台是GPU,若要切换为CPU,需要做如下改动(目前只改动了训练过程所需的,测试评估还未进行,但估计类似):

    1. args.use_gpu需要置为0,以保证int_dtypefloat_dtype不是cuda。
    2. 检索cuda(),可以发现在model.py还有些残缺未考虑的cuda定义,使用torch.cuda.is_available()判断是否GPU可使用,只有可行采用cuda()定义:
    x = xxx()
    if torch.cuda.is_available():
    	x = x.cuda() 
    

    池化层实现细节

    Social GAN相较于Social LSTM提出了新的池化模型以满足不同行人轨迹间信息共享与相互作用,具体有以下几个方面的变动:

    1. Social GAN的池化频率为一次,只在利用已知轨迹编码后进行一次池化。(代码中一个额外选项是在预测的每一步都进行池化)
    2. 池化范围为全局而不是固定的范围区间,代码使用max pooling的手段使得在场景人数不确定的情况下可以保持数据维度固定。
    3. 池化输入数据由两方面组成:LSTMs的隐藏状态+最后位置的相对信息

    而在代码实现时,计算相对位置信息时显得比较巧妙,例如在同场景的行人位置信息,代码通过两次不同的repeat策略将原有N个人的位置信息重复N次,从而形成了[P0, P0, P0, ...] [P1, P1, P1, ...] ... 和 [P0, P1, P2, ...] [P0, P1, P2, ...] ..两个矩阵,通过矩阵相减即可得到一个N*N行的矩阵,第(i)行是第(i \% N)个人相对于第(i / N)个人的相对位置。

    	curr_hidden = h_states.view(-1, self.h_dim)[start:end]
        curr_end_pos = end_pos[start:end]
    
        # Repeat -> [H1, H2, H3, ...][H1, H2, H3, ...]...
        curr_hidden_1 = curr_hidden.repeat(num_ped, 1)
        # Repeat position -> [P1, P2, P3, ...][P1, P2, P3, ...]...
        curr_end_pos_1 = curr_end_pos.repeat(num_ped, 1)
        # Repeat position -> [P1, P1, P1, ...][P2, P2, P2, ...]...
        curr_end_pos_2 = self.repeat(curr_end_pos, num_ped)
        # 得到行人的end_pos间的相对关系,并交给感知机去具体处理。
        # 每个行人与其他行人的相对位置关系由num_ped项,合计有num_ped**2项。
        curr_rel_pos = curr_end_pos_1 - curr_end_pos_2
        curr_rel_embedding = self.spatial_embedding(curr_rel_pos)
    
        # 拼接H_i和处理过的pos,放入多层感知机,最后经过maxPooling。
        mlp_h_input = torch.cat([curr_rel_embedding, curr_hidden_1], dim=1)
        curr_pool_h = self.mlp_pre_pool(mlp_h_input)
    

    DataLoader相关

    安利一个知乎,上面对使用Pytorch实现dataLoader解释得很细致

    https://zhuanlan.zhihu.com/p/30385675

    dataLoader迭代器的数据格式

    Dataset继承而来的TrajectoryDataSet__get_item__进行了重写,以方便dataLoader使用并整合,每次函数返回的是一个列表:

    out = [
        self.obs_traj[start:end, :], self.pred_traj[start:end, :],
        self.obs_traj_rel[start:end, :], self.pred_traj_rel[start:end, :],
        self.non_linear_ped[start:end], self.loss_mask[start:end, :]
    ]
    return out
    

    列表中有6个元素,以obs_traj为例,其大小为[N][2][seq_len],但是在使用dataLoader进行迭代时出现了这种形式,不仅一个batch中解压得到的变为7个,而且obs_traj的大小变为[seq_len][batch][2],,顺序发生了变化.

    batch = [tensor.cuda() for tensor in batch]
    (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, non_linear_ped,loss_mask, seq_start_end) = batch
    

    Solution

    1. 问题主要是忽视了DataLoadercollate_fn函数的作用,这个函数是在trajectory.py中自定义的函数,主要作用时当dataLoader收集到batch_sizeitem后形成一个列表,而后交由自定义的collate_fn做预处理,处理后的数据就会被输出为batch
    2. seq_collate解答了数据格式的两个疑问,包括使用permutecat函数。

    从dataLoader获取的batch数据的概念辨析

    Solution

    1. batch != batch_size

      1. 模型注释中有多处使用batch来表示张量格式,一个batch的数据常常有batch_size行,但在该模型中不成立。
      2. 严格来说,一个batch中有batch_sizeitem,但一个item可以用多行表示,这就是该模型的数据特点,其在一个batch中额外新增了seq_start_end列表(len(seq_start_end) == batch_size),使用该列表即可抽取出一个item

      [batch = Sigma_{i=0}^{batch\_size-1}N_i (N_i ge min\_peds) ]

      (N_i)表示一个场景下的行人个数。

    2. 一个batch中有多场景的行人轨迹数据

      1. LSTM编码和译码:每个轨迹都是独立的,此时可以整个batch一起处理
      2. 池化:设计同一场景下各行人序列数据交互,需要使用seq_start_end划分场景分别计算。
  • 相关阅读:
    Navicat连接mysql提示1251解决方案
    js获取select下拉框选中的值
    Windows下安装Mysql数据库
    ASP.NET MVC API以及.Core API进行安全拦截和API请求频率控制
    myeclipse 10.7中文破解版 下载安装看着一篇就够了
    Runtime exception at 0x004000bc: invalid integer input (syscall 5)
    MARS(MIPS汇编程序和运行时模拟器)
    如何将本地的代码上传到github
    JavaWeb基础
    大学什么时候开学?这款小程序告诉你!
  • 原文地址:https://www.cnblogs.com/sinoyou/p/11370554.html
Copyright © 2011-2022 走看看