  • PyTorch single-machine, multi-GPU parallel training example

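    A minimal example of data-parallel training on one machine with several GPUs, using DistributedDataParallel (DDP).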

    Note:

    os.environ['MASTER_ADDR'] = 'xxx.xx.xx.xxx' # this machine's IP address ('127.0.0.1' or 'localhost' also works on a single machine)
    os.environ['MASTER_PORT'] = '29555' # any free port

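    These two variables must be set before the process group is created, because the initialization method used here is init_method="env://" (the default, environment-variable method). With env://, init_process_group reads MASTER_ADDR and MASTER_PORT from the environment and requires both; RANK and WORLD_SIZE are read from there as well unless they are passed as the rank= and world_size= arguments, as this example does.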
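    The port in MASTER_PORT must not already be in use. A common trick, sketched here with only the Python standard library, is to bind a socket to port 0 so the operating system picks an unused port. The helper names find_free_port and set_master_env are hypothetical, and the sketch assumes a single machine, so it uses 127.0.0.1 rather than the machine's network IP. Call it in the parent process before mp.spawn so the worker processes inherit the variables.

```python
import os
import socket
from typing import Optional


def find_free_port(host: str = "127.0.0.1") -> int:
    """Hypothetical helper: ask the OS for an unused TCP port by binding to port 0."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind((host, 0))  # port 0 means "let the OS choose a free port"
        return s.getsockname()[1]


def set_master_env(addr: str = "127.0.0.1", port: Optional[int] = None) -> None:
    """Hypothetical helper: set the two variables that init_method="env://" requires."""
    os.environ["MASTER_ADDR"] = addr
    # Environment variables must be strings; pick a free port if none was given.
    os.environ["MASTER_PORT"] = str(port if port is not None else find_free_port(addr))


if __name__ == "__main__":
    set_master_env()
    print(os.environ["MASTER_ADDR"], os.environ["MASTER_PORT"])
```

    There is a small race: another program could take the port between this check and init_process_group, so a fixed port such as 29555 is just as reasonable when you know it is free.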

    # Single-machine, multi-GPU parallel training example
    
    import os
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp
    import torch.nn as nn
    import torch.optim as optim
    from torch.nn.parallel import DistributedDataParallel as DDP
    
    
    # https://pytorch.org/docs/stable/notes/ddp.html
    
    
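    def example(local_rank, world_size): # local_rank (0 .. nprocs-1) is passed in automatically by mp.spawn
        # create default process group
        # gloo works here; "nccl" is the recommended backend for GPU training on Linux
        dist.init_process_group(backend="gloo", init_method="env://", rank=local_rank, world_size=world_size)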
        # create local model
        model = nn.Linear(10, 10).cuda(local_rank)
        # construct DDP model
        ddp_model = DDP(model, device_ids=[local_rank], output_device=local_rank)
        # define loss function and optimizer
        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
    
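        for i in range(100):
            if local_rank == 0: # without this check, the line is printed once per spawned process
                print(i)
            # forward pass
            outputs = ddp_model(torch.randn(20, 10).cuda(local_rank))
            labels = torch.randn(20, 10).cuda(local_rank)
            # clear the gradients left over from the previous iteration
            optimizer.zero_grad()
            # backward pass; DDP synchronizes gradients across processes here
            loss_fn(outputs, labels).backward()
            # update parameters
            optimizer.step()
    
        # tear down the process group before the process exits
        dist.destroy_process_group()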
    
    
    def main():
        os.environ['MASTER_ADDR'] = 'xxx.xx.xx.xxx' # this machine's IP address ('127.0.0.1' also works on a single machine)
        os.environ['MASTER_PORT'] = '29555' # any free port
        world_size = torch.cuda.device_count() # one process per GPU
        mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)
    
    
    
    if __name__ == "__main__": # required: mp.spawn starts workers with the "spawn" method, which re-imports this file
        main()
        print('Done!')
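    During backward(), DDP all-reduces the gradients so that every process ends up with their average, which keeps all model replicas identical after optimizer.step() (details in the DDP notes linked in the code). The pure-Python simulation below illustrates that idea for a 1-D model y = w*x with an MSE loss; it does not use PyTorch, and local_grad, allreduce_mean and ddp_step are hypothetical names. For a mean loss like this one and equally sized shards, the averaged gradient equals the gradient over the full batch, which is why the processes together behave like training on the combined batch.

```python
from typing import List, Tuple

Batch = List[Tuple[float, float]]  # (x, y) pairs


def local_grad(w: float, shard: Batch) -> float:
    """dL/dw for L = mean((w*x - y)^2) over one process's shard."""
    return sum(2 * (w * x - y) * x for x, y in shard) / len(shard)


def allreduce_mean(values: List[float]) -> List[float]:
    """Simulated all-reduce: every rank receives the mean of all ranks' values."""
    mean = sum(values) / len(values)
    return [mean] * len(values)


def ddp_step(weights: List[float], shards: List[Batch], lr: float) -> List[float]:
    # Each "rank" computes a gradient on its own shard (its forward/backward pass).
    grads = [local_grad(w, s) for w, s in zip(weights, shards)]
    # Gradients are averaged across ranks, as DDP does during backward().
    synced = allreduce_mean(grads)
    # Every rank applies the same update, so the replicas stay identical.
    return [w - lr * g for w, g in zip(weights, synced)]


if __name__ == "__main__":
    data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]  # points on y = 2x
    world_size = 2
    shards = [data[r::world_size] for r in range(world_size)]  # one shard per "rank"
    full = local_grad(0.0, data)
    averaged = allreduce_mean([local_grad(0.0, s) for s in shards])
    print("full-batch grad:", full, "averaged grads:", averaged)
    weights = [0.0] * world_size  # every rank starts from the same weight
    for _ in range(100):
        weights = ddp_step(weights, shards, lr=0.01)
    print("weights per rank:", weights)  # identical values, close to 2.0
```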
    

      

    Go and become the person you want to be!
  • Original post: https://www.cnblogs.com/jiangkejie/p/14606915.html