zoukankan      html  css  js  c++  java
  • 计图MPI分布式多卡

    计图MPI分布式多卡

    计图分布式基于MPI(Message Passing Interface),主要阐述使用计图MPI,进行多卡和分布式训练。目前计图分布式处于测试阶段。

    计图MPI安装

    计图依赖OpenMPI,用户可以使用如下命令安装OpenMPI:

    sudo apt install openmpi-bin openmpi-common libopenmpi-dev

    计图会自动检测环境变量中是否包含mpicc,如果计图成功的检测到了mpicc,输出如下信息:

    [i 0502 14:09:55.758481 24 __init__.py:203] Found mpicc(1.10.2) at /usr/bin/mpicc

    如果计图没有在环境变量中找到mpi,用户也可以手动指定mpicc的路径告诉计图,添加环境变量即可:export mpicc_path=/you/mpicc/path

    OpenMPI安装完成以后,用户无需修改代码,需要做的仅仅是修改启动命令行,计图就会用数据并行的方式,自动完成并行操作。

    # 单卡训练代码

    python3.7 -m jittor.test.test_resnet

    # 分布式多卡训练代码

    mpirun -np 4 python3.7 -m jittor.test.test_resnet

    # 指定特定显卡的多卡训练代码

    CUDA_VISIBLE_DEVICES="2,3" mpirun -np 2 python3.7 -m jittor.test.test_resnet

    便捷性的背后,计图的分布式算子的支撑,计图支持的mpi算子后端会使用nccl进行进一步的加速。计图所有分布式算法的开发,均在Python前端完成,让分布式算法的灵活度增强,开发分布式算法的难度也大大降低。

    基于这些mpi算子接口,研发团队已经集成了如下三种分布式相关的算法:

    • 分布式数据并行加载
    • 分布式优化器
    • 分布式同步批归一化层

    用户在使用MPI进行分布式训练时,计图内部的Dataset类会自动并行分发数据,需要注意的是Dataset类中设置的Batch size是所有节点的batch size之和,也就是总batch size,不是单个节点接收到的batch size。

    MPI接口

    目前MPI开放接口如下:

    • jt.mpi: 计图的MPI模块,当计图不在MPI环境下时,jt.mpi == None, 用户可以用这个判断是否在mpi环境下。
    • jt.Module.mpi_param_broadcast(root=0): 将模块的参数从root节点广播给其他节点。
    • jt.mpi.mpi_reduce(x, op='add', root=0): 将所有节点的变量x使用算子op,reduce到root节点。如果op是’add’或者’sum’,该接口会把所有变量求和,如果op是’mean’,该接口会取均值。

     

    • jt.mpi.mpi_broadcast(x, root=0): 将变量x从root节点广播到所有节点。

     

    • jt.mpi.mpi_all_reduce(x, op='add'): 将所有节点的变量x使用一起reduce,并且吧reduce的结果再次广播到所有节点。如果op是’add’或者’sum’,该接口会把所有变量求和,如果op是’mean’,该接口会取均值。

     

    实例:MPI实现分布式同步批归一化层

    下面的代码是使用计图实现分布式同步批,归一化层的实例代码,在原来批归一化层的基础上,只需增加三行代码,就可以实现分布式的batch norm,添加的代码如下:

    # 将均值和方差,通过all reduce同步到所有节点

    if self.sync and jt.mpi:

        xmean = xmean.mpi_all_reduce("mean")

        x2mean = x2mean.mpi_all_reduce("mean")

    注:计图内部已经实现了同步的批归一化层,用户不需要自己实现

    分布式同步批归一化层的完整代码:

    class BatchNorm(Module):

        def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True, sync=True):

            assert affine == None

     

            self.sync = sync

            self.num_features = num_features

            self.is_train = is_train

            self.eps = eps

            self.momentum = momentum

            self.weight = init.constant((num_features,), "float32", 1.0)

            self.bias = init.constant((num_features,), "float32", 0.0)

            self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()

            self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()

     

        def execute(self, x):

            if self.is_train:

                xmean = jt.mean(x, dims=[0,2,3], keepdims=1)

                x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)

                # 将均值和方差,通过all reduce同步到所有节点

                if self.sync and jt.mpi:

                    xmean = xmean.mpi_all_reduce("mean")

                    x2mean = x2mean.mpi_all_reduce("mean")

     

                xvar = x2mean-xmean*xmean

                norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)

                self.running_mean += (xmean.sum([0,2,3])-self.running_mean)*self.momentum

                self.running_var += (xvar.sum([0,2,3])-self.running_var)*self.momentum

            else:

                running_mean = self.running_mean.broadcast(x, [0,2,3])

                running_var = self.running_var.broadcast(x, [0,2,3])

                norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)

            w = self.weight.broadcast(x, [0,2,3])

            b = self.bias.broadcast(x, [0,2,3])

            return norm_x * w + b

    人工智能芯片与自动驾驶
  • 相关阅读:
    LeetCode 121. Best Time to Buy and Sell Stock
    LeetCode 221. Maximal Square
    LeetCode 152. Maximum Product Subarray
    LeetCode 53. Maximum Subarray
    LeetCode 91. Decode Ways
    LeetCode 64. Minimum Path Sum
    LeetCode 264. Ugly Number II
    LeetCode 263. Ugly Number
    LeetCode 50. Pow(x, n)
    LeetCode 279. Perfect Squares
  • 原文地址:https://www.cnblogs.com/wujianming-110117/p/14394895.html
Copyright © 2011-2022 走看看