zoukankan      html  css  js  c++  java
  • Pytorch 多 GPU 并行处理机制

    Pytorch 的多 GPU 处理接口是 torch.nn.DataParallel(module, device_ids),其中 module 参数是所要执行的模型,而 device_ids 则是指定并行的 GPU id 列表。

    而其并行处理机制是,首先将模型加载到主 GPU 上,然后再将模型复制到各个指定的从 GPU 中,然后将输入数据按 batch 维度进行划分,具体来说就是每个 GPU 分配到的数据 batch 数量是总输入数据的 batch 除以指定 GPU 个数。每个 GPU 将针对各自的输入数据独立进行 forward 计算,最后将各个 GPU 的 loss 进行求和,再用反向传播更新单个 GPU 上的模型参数,再将更新后的模型参数复制到剩余指定的 GPU 中,这样就完成了一次迭代计算。所以该接口还要求输入数据的 batch 数量要不小于所指定的 GPU 数量。

     
     

    这里有两点需要注意:

    1. 主 GPU 默认情况下是 0 号 GPU,也可以通过 torch.cuda.set_device(id) 来手动更改默认 GPU。
    2. 提供的多 GPU 并行列表中需要包含有主 GPU。


    作者:叶俊贤
    链接:https://www.jianshu.com/p/9e36e5e36638
    来源:简书
    简书著作权归作者所有,任何形式的转载都请联系作者获得授权并注明出处。
  • 相关阅读:
    Supreme(ง •̀_•́)ง
    基于VS快速排序的单元测试
    POST GET
    Go对比其他语言新特性1(字符类型、类型转换、运算符、键盘输入、for、switch)
    四则运算问题
    软件工程第三次作业!
    Servlet
    结对编程1
    Kafka技术原理知识点总结
    KafkaStream简介
  • 原文地址:https://www.cnblogs.com/jfdwd/p/11197897.html
Copyright © 2011-2022 走看看