zoukankan      html  css  js  c++  java
  • [PyTorch]论文pytorch复现中遇到的BUG

    1. zip argument #1 must support iteration

    在多gpu训练的时候,自动把你的batch_size分成n_gpu份,每个gpu跑一些数据, 最后再合起来。我之所以出现这个bug是因为返回的时候 返回了一个常量。。

    2. torch.nn.DataParallel

    在使用torch.nn.DataParallel时候,要先把模型放在gpu上,再进行parallel。

    3. model.state_dict()

    一般在现有的网络加载预训练模型通常是找到预训练模型在现有的model里面的参数,然后model进行更新,遇到一个bug, 发现加载预训练模型的时候, 效果很差,跟参数没有更新一样,找了一大顿bug,最后才发现,之前是单gpu进行的预训练,现在的模型使用的是多gpu, 打印现在模型的参数你会发现他所有的参数前面都加了一个module. 所以向以前一样更新,没有一个参数会被更新,因此写了一个万能模型参数加载函数。

    pretrained_dict = checkpoint['state_dict']
    model_dict = self.model.state_dict()
    if checkpoint['config']['n_gpu'] > 1 and self.config['n_gpu'] == 1:
        new_dict = OrderedDict()
        for k, v in pretrained_dict.items():
            name = k[7:]
            new_dict[name] = v
        pretrained_dict = new_dict
    elif checkpoint['config']['n_gpu'] == 1 and self.config['n_gpu'] > 1:
        new_dict = OrderedDict()
        for k, v in pretrained_dict.items():
            name = "module."+k
            new_dict[name] = v
        pretrained_dict = new_dict
    print("The pretrained model's para is following")
    for k, v in pretrained_dict.items():
        print(k)
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    model_dict.update(pretrained_dict)
    self.model.load_state_dict(model_dict)
    
  • 相关阅读:
    【队列】队列的分类和实现
    【JSP】EL表达式语言
    【JSP】JSP的介绍和基本原理
    【JSP】JSP Action动作标签
    【Servlet】关于RequestDispatcher的原理
    【JSP】JSP指令
    【JSP】JSP中的Java脚本
    【算法】表达式求值--逆波兰算法介绍
    C语言指针详解
    移动架构-MVVM框架
  • 原文地址:https://www.cnblogs.com/kk17/p/10139884.html
Copyright © 2011-2022 走看看