zoukankan      html  css  js  c++  java
  • Pytorch DistributedDataParallel简明使用指南

    DistributedDataParallel(简称DDP)是PyTorch自带的分布式训练框架, 支持多机多卡和单机多卡, 与DataParallel相比起来, DDP实现了真正的多进程分布式训练.

    [原创][深度][PyTorch] DDP系列第一篇:入门教程
    当代研究生应当掌握的并行训练方法(单机多卡)

    DDP的原理和细节推荐上述两篇文章, 本文的主要目的是简要归纳如何在PyTorch代码中添加DDP的部分, 实现单机多卡分布式训练.

    Import部分:

    import numpy as np
    import torch
    import random
    
    import argparse
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP
    from torch.utils.data.distributed import DistributedSampler
    

    在使用DDP训练的过程中, 代码需要知道当前进程是在哪一块GPU上跑的, 这里对应的本地进程序号local_rank(区别于多机多卡时的全局进程序号, 指的是一台机器上的进程序号)是由DDP自动从外部传入的, 我们使用argparse获取该参数即可.

    parser = argparse.ArgumentParser(description='Network Parser')
    args = parser.parse_args()
    local_rank = args.local_rank
    

    获取到local_rank后, 我们可以对模型进行初始化或加载等操作, 注意这里torch.load()要添加map_location参数, 否则可能导致读取进来的数据全部集中在0卡上. 模型构建完以后, 再将模型转移到DDP上:

    torch.cuda.set_device(local_rank)
    model = YourModel()
    # 如果需要加载模型
    if args.resume_path: 
    	checkpoint = torch.load(args.resume_path, map_location=torch.device("cpu"))  
    	model.load_state_dict(checkpoint["state_dict"])
    
    # 要在模型初始化或加载完后再进行
    # SyncBatchNorm不是必选项, 可以将模型中的BatchNorm层转换为进程之间同步数据的SyncBatchNorm层, 从而缓解Batch size较小时BN效果差的问题
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model).cuda() 
    model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
    

    这里有一个小细节, 假如你的model中用到了随机数种子来保证可复现性, 那么此时我们不能再用固定的常数作为seed, 否则会导致DDP中的所有进程都拥有一样的seed, 进而生成同态性的数据:

    def init_seeds(seed=0, cuda_deterministic=True):
    	random.seed(seed)
    	np.random.seed(seed)
    	torch.manual_seed(seed)
    	if cuda_deterministic:  # slower, more reproducible
    		cudnn.deterministic = True
    		cudnn.benchmark = False
    	else:  # faster, less reproducible
    	cudnn.deterministic = False
    	cudnn.benchmark = True
    
    def main():
    	random_seed = 1234
    	init_seeds(random_seed+local_rank)
    

    model部分处理完以后, 构建optimizer

    optimizer = YourOptimizer()
    if args.resume_path:
    	optimizer.load_state_dict(checkpoint["optimizer"])
    

    对于dataloader部分, 为了让DDP中同时运行的多个进程使用不同数据, 我们需要引入一个专用的sampler. 注意这里无需再指定shuffle=True, 因为sampler会在后续的set_epoch()帮我们打乱数据.

    args.batch_size = args.batch_size // torch.cuda.device_count  # 这一步是因为我传入的参数里batch_size代表所有GPU的batch之和, 所以要除以GPU的数量
    train_dataset = YourDataset()
    train_sampler = DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, sampler=train_sampler, pin_memory=True, drop_last=True)
    

    对于val_loader, 一般不需要使用上述sampler, 只要保留原始的dataloader代码即可.

    模型和数据都准备好了, 接下来只需简单的操作:

    for epoch in range(0, epoch_num):
    	train_sampler.set_epoch(epoch)  # 这一步是为了让数据shuffle在每一个epoch都正常工作
    	train()
    

    在DDP训练的时候, 我个人习惯的流程是print, tensorboard, validate, 模型sava这类操作都只让其中一个进程完成, 这样可以避免冗余信息的出现. 注意DDP的时候保存的是model.module.state_dict().

    from torch.utils.tensorboard import SummaryWriter
    # 保存模型
    if dist.get_rank() == 0:
    	validate()
    	state_checkpoint = {
    			'state_dict': model.module.state_dict(), 
    			'optimizer':optimizer.state_dict()}
    	torch.save(state_checkpoint, model_name)
    )
    # 打印信息并在tensorboard绘制
    def main():
    	if dist.get_rank() == 0:
    		writer = SummaryWriter()
    	else:
    		writer = None
    	train(your_args, writer)
    
    def train(your_args, writer):
    	if dist.get_rank() == 0:
    		writer.add_scalar('Train/Loss', your_value, global_step=your_step)
    		print("Your information")
    

    注意这里创建tensorboard.SummaryWriter的进程和后续写入的进程要统一, 假如所有进程都创建, 只有一个进程写入的话, 会导致tensorboard不显示数据, 假如所有进程都创建, 所有进程都写入的话, 会导致tensorboard上出现许多条线(多个进程同时写入)

    将上述代码补充到原来的训练代码中后, 就可以模型愉快地进行DDP训练了, 这里nproc_per_node是使用的GPU数量, 我的代码中batch_size指的是所有GPU上的batch总和, 使用了4张卡, 所以实际上每张GPU上的mini-batch=8

    CUDA_VISIBLE_DEVICES="0,1,2,3" nohup python3 -u -m torch.distributed.launch --nproc_per_node=4 train.py --batch_size 32 --your_args "your args here" > log.out 2>&1 &
    
  • 相关阅读:
    DAOFactory复用代码
    WebUtils复用代码【request2Bean、UUID】
    过滤器复用代码【中文乱码、HTML转义】
    数据库复用代码【c3p0配置文件、数据库连接池】
    分页复用代码【Page类、JSP显示页面】
    AJAX应用【股票案例】
    JavaScript中的for in循环
    JSON【介绍、语法、解析JSON】
    javaScript【创建对象、创建类、成员变量、方法、公有和私有、静态】
    DOM【介绍、HTML中的DOM、XML中的DOM】
  • 原文地址:https://www.cnblogs.com/limitlessun/p/14716406.html
Copyright © 2011-2022 走看看