zoukankan      html  css  js  c++  java
  • pytorch单机多卡并行计算示例

    一个简单的例子。

    注意:

    os.environ['MASTER_ADDR'] = 'xxx.xx.xx.xxx' # 这里填写电脑的IP地址
    os.environ['MASTER_PORT'] = '29555' # 空闲端口

    这两个参数似乎必须提前给出,选择的初始化方法为init_method="env://"(默认的环境变量方法)

    # 单机多卡并行计算示例
    
    import os
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp
    import torch.nn as nn
    import torch.optim as optim
    from torch.nn.parallel import DistributedDataParallel as DDP
    
    
    # https://pytorch.org/docs/stable/notes/ddp.html
    
    
    def example(local_rank, world_size): # local_rank由mp.spawn自动给出
        # create default process group
        dist.init_process_group(backend="gloo", init_method="env://", rank=local_rank, world_size=world_size)
        # create local model
        model = nn.Linear(10, 10).cuda(local_rank)
        # construct DDP model
        ddp_model = DDP(model, device_ids=[local_rank], output_device=local_rank)
        # define loss function and optimizer
        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
    
        # forward pass
        for i in range(100):
            if local_rank == 0: # 这里开几个进程就会打印几次
                print(i)
            outputs = ddp_model(torch.randn(20, 10).cuda(local_rank))
            labels = torch.randn(20, 10).cuda(local_rank)
            # backward pass
            loss_fn(outputs, labels).backward()
            # update parameters
            optimizer.step()
    
    
    def main():
        os.environ['MASTER_ADDR'] = 'xxx.xx.xx.xxx' # 这里填写电脑的IP地址
        os.environ['MASTER_PORT'] = '29555' # 空闲端口
        world_size = torch.cuda.device_count()
        mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)
    
    
    
    if __name__=="__main__":
        main()
        print('Done!')
    

      

    快去成为你想要的样子!
  • 相关阅读:
    js学习笔记7----return,arguments及获取元素样式
    js学习笔记6----作用域及解析机制
    js学习笔记5----函数传参
    js学习笔记4----数据类型
    Flashtext 使用文档 大规模数据清洗的利器-实现文本结构化
    Linux之目录的操作(创建、移动、改名、删除、复制)
    Python 异常处理
    Python 内置模块函数filter reduce
    Python处理文件以及文件夹常用方法
    Python 字符串常用方法
  • 原文地址:https://www.cnblogs.com/jiangkejie/p/14606915.html
Copyright © 2011-2022 走看看