zoukankan      html  css  js  c++  java
  • pytorch 多gpu训练

    pytorch 多gpu训练

    用nn.DataParallel重新包装一下

    数据并行有三种情况

    前向过程

    device_ids=[0, 1, 2]
    model = model.cuda(device_ids[0])
    model = nn.DataParallel(model, device_ids=device_ids)
    

    只要将model重新包装一下就可以。

    后向过程

    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.001)
    optimizer = nn.DataParallel(optimizer, device_ids=device_ids)
    #因为它在DataParallel里面,所以要先变成普通的nn.SGD对象,然后才能调用该类的梯度更新方法。
    optimizer.module.step() 
    

    在网上看到别人这样写了,做了一下测试。但是显存没有变化,不知道它的影响是怎样的。
    更新学习率的时候也需要注意一下:

    for param_lr in optimizer.module.param_groups: #同样是要加module
        param_lr['lr'] /= 2
    

    criterion(loss 函数)

    def init_criterion():
        criterion = loss.CrossEntropyLoss2d()
        criterion = torch.nn.DataParallel(
                criterion, range(gpu_nums)).cuda()  # range(self.settings.n_gpu)
        return criterion
        
    # criterion = init_criterion()
    criterion = loss.CrossEntropyLoss2d()
    

    这个并行的效果对显存是有影响的,但是效果不明显。我没有做太多实验。
    训练的时候会出现问题:

    loss = criterion(out, labels_tensor)
    loss /= N
    optimizer.zero_grad()
    # loss.backward()
    loss.sum().backward()
    

    数据并行返回的结果的维度和之前维度是不一样的所以反向传播的时候需要做一下修改

  • 相关阅读:
    linux试题
    linux常用脚本
    nagios
    lvs/nginx/haproxy 负载均衡优缺点分析讲解
    一次SSLPeerUnverifiedException,SSLHandshakeException问题的分析
    [转]【安卓笔记】AsyncTask源码剖析
    linux下查看进程占用端口和端口占用进程命令
    which framework or library is best to use WebRTC
    [转]svn diff 替代工具
    [转]使用Subversion进行版本控制
  • 原文地址:https://www.cnblogs.com/o-v-o/p/9975357.html
Copyright © 2011-2022 走看看