  • PyTorch single-machine, multi-GPU parallel computing example

    A simple example using DistributedDataParallel (DDP) with one process per GPU.

    Note:

    os.environ['MASTER_ADDR'] = 'xxx.xx.xx.xxx'  # this machine's IP address ('127.0.0.1' also works when all processes run on one machine)
    os.environ['MASTER_PORT'] = '29555'  # any free port

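    These two variables must be set before the process group is created. The example uses init_method="env://" (the default, environment-variable initialization), so init_process_group reads MASTER_ADDR and MASTER_PORT from the environment, and also RANK and WORLD_SIZE unless they are passed as arguments, as this example does. Setting them in main() before mp.spawn is enough, because the spawned processes inherit the parent's environment.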
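    As a rough preflight check, the sketch below lists which variables init_method="env://" would still be missing, and picks a free port for MASTER_PORT by binding to port 0. The helper names find_free_port and missing_env_vars are hypothetical, not part of PyTorch; the variable list (MASTER_ADDR, MASTER_PORT, plus RANK / WORLD_SIZE when they are not passed to init_process_group) follows the env:// behavior described above.

```python
import os
import socket


def find_free_port() -> int:
    """Ask the OS for an unused TCP port (hypothetical helper, not a PyTorch API)."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))  # port 0 lets the OS pick a free port
        return s.getsockname()[1]


def missing_env_vars(rank=None, world_size=None, environ=None):
    """Return the variables init_method="env://" would still need to read.

    rank / world_size mirror the arguments of dist.init_process_group:
    when they are passed explicitly, RANK / WORLD_SIZE are not required.
    """
    environ = os.environ if environ is None else environ
    required = ["MASTER_ADDR", "MASTER_PORT"]
    if rank is None:
        required.append("RANK")
    if world_size is None:
        required.append("WORLD_SIZE")
    return [name for name in required if name not in environ]


if __name__ == "__main__":
    env = {"MASTER_ADDR": "127.0.0.1", "MASTER_PORT": str(find_free_port())}
    # Like the example, pass rank and world_size explicitly:
    # only the two MASTER_* variables are then needed.
    print(missing_env_vars(rank=0, world_size=2, environ=env))  # []
    print(missing_env_vars(environ={}))  # ['MASTER_ADDR', 'MASTER_PORT', 'RANK', 'WORLD_SIZE']
```

    The port is released when the probe socket closes, so another program could in principle take it before the process group binds it. On a single workstation that race is rarely a problem.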

    # Single-machine, multi-GPU parallel computing example
    
    import os
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp
    import torch.nn as nn
    import torch.optim as optim
    from torch.nn.parallel import DistributedDataParallel as DDP
    
    
    # https://pytorch.org/docs/stable/notes/ddp.html
    
    
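    def example(local_rank, world_size):  # local_rank is passed in automatically by mp.spawn (0 .. nprocs-1)
        # bind this process to its own GPU
        torch.cuda.set_device(local_rank)
        # create the default process group; "env://" reads MASTER_ADDR / MASTER_PORT from the environment.
        # "nccl" is usually the faster backend for CUDA tensors on Linux; "gloo" also works (and is the option on Windows).
        dist.init_process_group(backend="gloo", init_method="env://", rank=local_rank, world_size=world_size)
        # create local model on this process's GPU
        model = nn.Linear(10, 10).cuda(local_rank)
        # construct DDP model
        ddp_model = DDP(model, device_ids=[local_rank], output_device=local_rank)
        # define loss function and optimizer
        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

        for i in range(100):
            if local_rank == 0:  # every process runs this loop; without the check, i would be printed once per process
                print(i)
            # forward pass
            outputs = ddp_model(torch.randn(20, 10).cuda(local_rank))
            labels = torch.randn(20, 10).cuda(local_rank)
            # clear gradients left over from the previous iteration
            optimizer.zero_grad()
            # backward pass (DDP all-reduces the gradients across processes here)
            loss_fn(outputs, labels).backward()
            # update parameters
            optimizer.step()

        # release the process group before the process exits
        dist.destroy_process_group()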
    
    
    def main():
        os.environ['MASTER_ADDR'] = 'xxx.xx.xx.xxx'  # this machine's IP address ('127.0.0.1' also works on a single machine)
        os.environ['MASTER_PORT'] = '29555'  # any free port
        world_size = torch.cuda.device_count()  # one process per visible GPU
        mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)
    
    
    
    if __name__=="__main__":
        main()
        print('Done!')
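    During backward(), DDP averages each parameter's gradient across all processes (an all-reduce), so every replica applies the same update and the model copies stay identical. The pure-Python sketch below imitates only that arithmetic for a one-parameter linear model with MSE loss on two simulated ranks. It is a conceptual illustration, not how PyTorch implements it: the real DDP all-reduces bucketed gradient tensors through Gloo or NCCL, overlapped with the backward pass. The names local_gradient, allreduce_mean and ddp_step are hypothetical.

```python
def local_gradient(w, xs, ys):
    """d/dw of mean((w*x - y)^2) over one rank's batch."""
    n = len(xs)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def allreduce_mean(values):
    """Simulated all-reduce: every rank ends up with the average of all ranks' values."""
    avg = sum(values) / len(values)
    return [avg] * len(values)


def ddp_step(weights, batches, lr):
    """One simulated DDP step: local backward, gradient all-reduce, identical SGD update."""
    grads = [local_gradient(w, xs, ys) for w, (xs, ys) in zip(weights, batches)]
    synced = allreduce_mean(grads)
    return [w - lr * g for w, g in zip(weights, synced)]


if __name__ == "__main__":
    world_size = 2
    # DDP broadcasts rank 0's initial parameters, so all replicas start equal
    weights = [0.0] * world_size
    # each rank sees different data drawn from y = 2x, with equal batch sizes
    batches = [([1.0, 2.0], [2.0, 4.0]), ([3.0, 4.0], [6.0, 8.0])]
    for _ in range(50):
        weights = ddp_step(weights, batches, lr=0.05)
    print(weights)  # both replicas hold the same value, very close to 2.0
```

    With equal per-rank batch sizes, the averaged gradient equals the gradient over the combined batch, which is why DDP training behaves like single-process training on a batch world_size times larger.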
    

  • Original article: https://www.cnblogs.com/jiangkejie/p/14606915.html