zoukankan      html  css  js  c++  java
  • pytorch multi-gpu train

    记录一下pytorch如何进行单机多卡训练:

    官网例程:https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html

    下面以一个例子讲解一下,例如现在总共有8张卡,在第5、6、7三张卡上进行训练;

    step 1:可视化需要用到的GPU

    import os

    os.environ["CUDA_VISIBLE_DEVICES"] = "5 , 6 , 7"

    device = torch.device("cuda:0")  #注意多卡训练的时候,默认都先将model和data先保存在id:0的卡上(即实际的第5块卡),然后model的参数会复制共享到其他卡上,data也会平分成若干个batch到其他卡上(所以第一块卡上稍微耗费一点显存);

    device_ids = [0 , 1 , 2] #注意device_ids必须从0开始,代码中的所有的device id都需要从0开始(这里的0代表第5块卡,1代表第6块卡,类推);

    step 2:利用DataParallel对Model类进行封装

    model = nn.DataParallel(model , device_ids = device_ids)

    model.to(device)

    step 3:

    data.to(device)  #id:0卡上的数据再被平分成若干个batch到其他卡上

    注意:晚上还有一些例程,需要对optimizer和loss利用DataParellel进行封装,没有试验过,但上面方法是参考官网例程,并经过实操考验;

  • 相关阅读:
    ES6 Promise 用法转载
    移动端滚动性能优化
    Python之禅
    Day01~15
    Python
    第一章 Java起源
    IMP-00009: 导出文件异常结束 imp
    浏览器访问网页的详细内部过程
    数据库连接池
    连接数据库 六大步骤
  • 原文地址:https://www.cnblogs.com/zf-blog/p/10599328.html
Copyright © 2011-2022 走看看