zoukankan      html  css  js  c++  java
  • jupyter notebook加载DDP预训练模型

    最近遇到了一个问题,模型是用DistributedDataParallel一机多卡分布式训练的,然后作为一个jupyter notebook重度用户,我想用它来加载这个模型,搞点预测的例子可视化看看。

    但是这会碰到一个问题,我们都知道通常加载预训练模型的方法是:

    pretrained_dict = torch.load(pretrained_path, map_location=device)
    
    model.load_state_dict(pretrained_dict,strict=True)

    但是要想load_state_dict在DDP下训练的模型参数,首先初始化的模型model也需要在DDP下初始化,而我尝试了很久,发现没法在jupyter上初始化分布式环境:

    torch.distributed.init_process_group(backend='nccl',
                                                 init_method='env://')

    然后想了一个解决办法,查看一下torch.load之后的pretrained_dict字典参数,其中有很多项内容,可以看下我保存模型的时候:

    save_checkpoint({
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_score': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, is_best, work_dir = args.work_dir)

    所以想要加载模型参数,首先就要取出'state_dict',然后看下model.state_dict()里的数据结构,发现参数变量名是套在module.model下的,而我们初始化的模型结构model, 其model.state_dict()参数变量是直接model.XX的,所以就把预训练模型的参数变量名过滤掉moduel,然后用初始化模型model去load_state_dict它就好了;

    整体代码:

    model = XXX #初始化模型结构
    pretrained_path = 'XXX'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    pretrained_dict = torch.load(pretrained_path, map_location=device)['state_dict']
    pretrained_dict = {k[7:]:v for k,v in pretrained_dict.items()} #k[X:]看情况调整
    
    model.load_state_dict(pretrained_dict,strict=True)

    注意,load_state_dict里的参数strict还是需要True来严格对齐,如果False的话,预训练的模型参数就会不严格加载,导致后续性能出现偏差。

    人生苦短,何不用python
  • 相关阅读:
    [网鼎杯 2018]Comment-Git泄露部分
    Google Hacking 详解
    macOS修改Docker容器的端口映射配置
    CentOS6 7 8更换阿里yum源
    XSS代码合集(含测试效果详细版)-HTML4与更早版本的向量2
    VMware 启动Ubuntu时黑屏
    XSS代码合集(含测试效果详细版)-HTML4与更早版本的向量1
    APP安全在线检测网站
    Juice-Shop 二星题
    慕课网-安卓工程师初养成-4-5 练习题
  • 原文地址:https://www.cnblogs.com/yqpy/p/14962338.html
Copyright © 2011-2022 走看看