  • Sharing an MPI tutorial for Python: A Python Introduction to Parallel Programming with MPI 1.0.2 documentation (continued, #2)

    Continued from the previous post:

    Sharing an MPI tutorial for Python: A Python Introduction to Parallel Programming with MPI 1.0.2 documentation

    https://materials.jeremybejarano.com/MPIwithPython/collectiveCom.html

    Collective Communication

    Reduce(…) and Allreduce(…)

    Examples:

    Reduce

    import numpy
    from mpi4py import MPI
    
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()
    
    rankF = numpy.array(float(rank))
    
    if rank == 0:
        total = numpy.zeros(1)
    else:
        total = None
    
    comm.Reduce(rankF, total, op=MPI.MAX)
    #comm.Reduce(rankF, total, op=MPI.SUM)
    
    if rank == 0:
        print("total: ", total)

    Allreduce

    import numpy
    from mpi4py import MPI
    
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()
    
    rankF = numpy.array(float(rank))
    
    total = numpy.zeros(1)
    
    comm.Allreduce(rankF, total, op=MPI.MAX)
    
    print("rank {}  :  total {} ".format(rank, total))
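
    To make the difference between the two calls concrete, here is a serial sketch (no MPI involved) of what they compute when each rank contributes its own rank number. The helper names simulate_reduce and simulate_allreduce are made up for this illustration and are not part of mpi4py; they only model where the combined value ends up.

```python
from functools import reduce


def simulate_reduce(values, op, root=0):
    """Model Reduce: only the root rank ends up with the combined value."""
    combined = reduce(op, values)
    return [combined if rank == root else None for rank in range(len(values))]


def simulate_allreduce(values, op):
    """Model Allreduce: every rank ends up with the combined value."""
    combined = reduce(op, values)
    return [combined for _ in values]


size = 4  # assumption: pretend the script was started with 4 processes
contributions = [float(rank) for rank in range(size)]

# max stands in for MPI.MAX, addition stands in for MPI.SUM
print(simulate_reduce(contributions, max))
print(simulate_reduce(contributions, lambda a, b: a + b))
print(simulate_allreduce(contributions, max))
```

    With 4 processes, MAX yields 3.0 and SUM yields 6.0. In the Reduce model only rank 0 holds the value, while in the Allreduce model all four ranks hold it, which is why the Allreduce script can print total on every rank.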

    Scatter

    # dotProductParallel_1.py
    # "to run" syntax example: mpiexec -n 4 python dotProductParallel_1.py 40000
    from mpi4py import MPI
    import numpy
    import sys
    
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()
    
    # read the vector length from the command line; default to 10000 if omitted
    n = int(sys.argv[1]) if len(sys.argv) > 1 else 10000
    
    # arbitrary example vectors, generated to be evenly divided by the number of
    # processes for convenience
    
    x = numpy.linspace(0, 100, n) if comm.rank == 0 else None
    y = numpy.linspace(20, 300, n) if comm.rank == 0 else None
    
    # initialize as numpy arrays
    dot = numpy.array([0.])
    local_n = numpy.array([0], dtype=numpy.int32)
    
    # test for conformability
    if rank == 0:
        if n != y.size:
            print("vector length mismatch")
            comm.Abort()
    
        # currently, our program cannot handle sizes that are not evenly divided by
        # the number of processors
        if n % size != 0:
            print("the number of processors must evenly divide n.")
            comm.Abort()
    
        # length of each process's portion of the original vector
        # (integer division; n % size == 0 was checked above)
        local_n = numpy.array([n // size], dtype=numpy.int32)
    
    # communicate local array size to all processes
    comm.Bcast(local_n, root=0)
    
    # initialize as numpy arrays
    local_x = numpy.zeros(local_n[0])
    local_y = numpy.zeros(local_n[0])
    
    # divide up vectors
    comm.Scatter(x, local_x, root=0)
    comm.Scatter(y, local_y, root=0)
    
    # local computation of dot product
    local_dot = numpy.array([numpy.dot(local_x, local_y)])
    
    # sum the results of each
    #comm.Reduce(local_dot, dot, op=MPI.SUM)
    comm.Allreduce(local_dot, dot, op=MPI.SUM)
    
    
    print("The dot product is", dot[0], "computed in parallel")
    
    if rank == 0:
        #print("The dot product is", dot[0], "computed in parallel")
        print("and", numpy.dot(x, y), "computed serially")
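
    The parallel version works because a dot product is a sum over elements, so the per-chunk dot products add up to the full one. The following serial sketch checks that without MPI: numpy.split stands in for Scatter and a plain sum stands in for Reduce/Allreduce with MPI.SUM. The value size = 4 is an assumption chosen for the illustration.

```python
import numpy

n = 10000
size = 4  # assumed number of processes; must divide n evenly

x = numpy.linspace(0, 100, n)
y = numpy.linspace(20, 300, n)

# numpy.split plays the role of Scatter: equal, contiguous chunks in rank order
x_chunks = numpy.split(x, size)
y_chunks = numpy.split(y, size)

# each "rank" computes the dot product of its own chunk
local_dots = [numpy.dot(xc, yc) for xc, yc in zip(x_chunks, y_chunks)]

# summing the partial results plays the role of the MPI.SUM reduction
parallel_style = sum(local_dots)
serial = numpy.dot(x, y)

print(parallel_style, serial)
print(numpy.isclose(parallel_style, serial))
```

    The two results may differ in the last few digits because floating-point additions happen in a different order, so the comparison uses numpy.isclose rather than ==.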

    Scatterv(…) and Gatherv(…)

    # for correct performance, run unbuffered with 3 processes:
    # mpiexec -n 3 python -u scratch.py
    import numpy
    from mpi4py import MPI
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    
    
    if rank == 0:
        x_global = numpy.linspace(0,100,11)
    else:
        x_global = None
    
    
    if rank == 0:
        x_local = numpy.zeros(1)
    elif rank == 1:
        x_local = numpy.zeros(1)
    elif rank == 2:
        x_local = numpy.zeros(9)
    
    if rank == 0:
        print("Scatter")
    
    comm.Scatterv([x_global, (1,1,9), (0,1,2), MPI.DOUBLE], x_local)
    print("process " + str(rank) + " has " + str(x_local))
    
    comm.Barrier()
    
    if rank == 0:
        print("Gather")
        xGathered = numpy.zeros(11)
    else:
        xGathered = None
    
    comm.Gatherv(x_local, [xGathered, (1,1,9), (0,1,2), MPI.DOUBLE])
    
    print("process " + str(rank) + " has " +str(xGathered))

    Run this code with:

    mpiexec -np 3 python x.py
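
    The list passed to Scatterv has the form [buffer, counts, displacements, datatype]: rank i receives counts[i] elements starting at index displacements[i] of the root's buffer. As a serial sketch of that rule (plain NumPy slicing, no MPI), the pieces for the call with (1,1,9) and (0,1,2) can be worked out like this:

```python
import numpy

x_global = numpy.linspace(0, 100, 11)  # 0, 10, 20, ..., 100
counts = (1, 1, 9)                     # number of elements sent to ranks 0, 1, 2
displs = (0, 1, 2)                     # start index of each rank's slice

# the slice each rank would receive from Scatterv
pieces = [x_global[start:start + count] for count, start in zip(counts, displs)]

for rank, piece in enumerate(pieces):
    print("rank", rank, "receives", piece)

# Gatherv with the same counts and displacements is the inverse operation:
# putting the pieces back at their offsets reproduces the original array
rebuilt = numpy.concatenate(pieces)
print(numpy.array_equal(rebuilt, x_global))
```

    Rank 0 gets the single value 0, rank 1 gets 10, and rank 2 gets the remaining nine values from 20 to 100.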

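    One point in the Scatterv/Gatherv script above is easy to overlook: comm.Scatterv does not synchronize the processes. It is blocking only in the MPI sense that it returns once the calling process's own buffers are safe to reuse. So if rank 0 does not synchronize afterwards with comm.Barrier, it may carry on without waiting for ranks 1 and 2 to finish receiving their data into x_local.

    The modified code: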

    # for correct performance, run unbuffered with 3 processes:
    # mpiexec -n 3 python -u scratch.py
    import numpy
    from mpi4py import MPI
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    
    
    if rank == 0:
        x_global = numpy.linspace(0,100,11)
    else:
        x_global = None
    
    
    if rank == 0:
        x_local = numpy.zeros(1)
    elif rank == 1:
        x_local = numpy.zeros(1)
    elif rank == 2:
        x_local = numpy.zeros(9)
    
    if rank == 0:
        print("Scatter")
    
    if rank != 0:
        import time
        time.sleep(10)
    
    comm.Scatterv([x_global, (1,1,9), (0,1,2), MPI.DOUBLE], x_local)
    print("process " + str(rank) + " has " + str(x_local))
    
    #comm.Barrier()
    
    if rank == 0:
        print("Gather")
        xGathered = numpy.zeros(11)
    else:
        xGathered = None
        
    comm.Gatherv(x_local, [xGathered, (1,1,9), (0,1,2), MPI.DOUBLE])
    
    print("process " + str(rank) + " has " +str(xGathered))

    When run, this code first prints:

    Scatter
    process 0 has [0.]
    Gather

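    It then stalls for roughly 10 seconds. This shows that, without comm.Barrier, rank 0 returned from comm.Scatterv and carried on without waiting for the other processes to receive their data. The output also shows that comm.Gatherv does make rank 0 wait: the root cannot return until it has received every process's contribution, so rank 0 blocks there for about 10 seconds.

    Because the later Gatherv already forces rank 0 to wait, leaving out comm.Barrier causes no problem in this code. Strictly, a barrier is not required for correctness after a collective such as Scatterv; it is useful when all processes must reach the same point before continuing, for example to keep the printed output in order, and adding it is a cheap precaution if you are unsure.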
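
    The behaviour described above (the root returns from the scatter early but has to wait in the gather) can be mimicked with threads and queues. This is only an analogy and not how MPI is implemented; whether a real Scatterv returns early on the root depends on the MPI implementation and the message size. All names in the sketch (inboxes, results, worker) are made up for the illustration.

```python
import queue
import threading
import time

events = []  # ordered log of what happened
log_lock = threading.Lock()


def log(message):
    with log_lock:
        events.append(message)


inboxes = {1: queue.Queue(), 2: queue.Queue()}  # root -> worker buffers
results = queue.Queue()                         # worker -> root


def worker(rank):
    time.sleep(0.5)             # the worker is late, like the time.sleep(10) above
    data = inboxes[rank].get()  # only now does it pick up its data
    log("rank %d: received" % rank)
    results.put((rank, data))   # its contribution to the gather


threads = [threading.Thread(target=worker, args=(rank,)) for rank in (1, 2)]
for t in threads:
    t.start()

# "scatter" on the root: hand the data to the buffers and return immediately
inboxes[1].put([10.0])
inboxes[2].put([20.0, 30.0])
log("rank 0: scatter returned")

# "gather" on the root: cannot finish until every worker has contributed
gathered = dict(results.get() for _ in threads)
log("rank 0: gather returned")

for t in threads:
    t.join()
print(events)
```

    The log starts with the root's scatter returning and ends with the root's gather returning, with both workers receiving their data in between.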

    This code is run with the same command:

    mpiexec -np 3 python x.py

    ==================================================

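    The code above hard-codes the number of processes: running it with anything other than -np 3 raises an error. Hard-coding the count like this is poor practice, so we want a version that does not fix the number of processes and still balances the computational load well. My own version follows.

    Code that balances the computational load: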

    # for correct performance, run unbuffered; any number of processes works, e.g.:
    # mpiexec -n 23 python -u scratch.py
    import numpy
    from mpi4py import MPI
    
    
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()
    
    
    n = 10000
    if rank == 0:
        x_global = numpy.linspace(0,100,n)
    else:
        x_global = None
    
    
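    # number of elements each rank receives: n // size, with the remainder
    # spread one element each over the last n % size ranks
    n_local = numpy.zeros(size, dtype=numpy.int32)
    n_local[:] = n // size
    if n%size != 0:
        n_local[-(n%size):] += 1
    
    # starting offset of each rank's slice; offsets must be integers
    begin_local = numpy.zeros(size, dtype=numpy.int32)
    for i in range(1, size):
        begin_local[i] = begin_local[i-1] + n_local[i-1]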
    
    x_local = numpy.zeros(n_local[rank])
    
    
    #if rank != 0:
    #    import time
    #    time.sleep(5)
    if rank == 0:
        print("Scatter")
    
    comm.Scatterv([x_global, n_local, begin_local, MPI.DOUBLE], x_local)
    print("process " + str(rank) + " has " + str(x_local[:5]))
    
    comm.Barrier()
    
    if rank == 0:
        print("Gather")
        xGathered = numpy.zeros(n)
    else:
        xGathered = None
        
    comm.Gatherv(x_local, [xGathered, n_local, begin_local, MPI.DOUBLE])
    
    print("process " + str(rank) + " has " +str(xGathered))

    Run command:

    mpiexec -np 23 python  x.py

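    The improved code no longer hard-codes the total number of processes. It balances the load for any process count, dividing the work as evenly as possible among all processes based on how many are actually running.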
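
    The count and offset computation can be checked on its own, without MPI. The sketch below wraps the same scheme in a helper named partition (a hypothetical name, not an mpi4py function) and prints the result for the 23-process run used above.

```python
import numpy


def partition(n, size):
    """Return (counts, displs) splitting n items over size ranks as evenly as possible."""
    counts = numpy.full(size, n // size, dtype=numpy.int32)
    remainder = n % size
    if remainder:
        counts[-remainder:] += 1  # the last `remainder` ranks take one extra item
    displs = numpy.zeros(size, dtype=numpy.int32)
    displs[1:] = numpy.cumsum(counts)[:-1]  # each slice starts where the previous ends
    return counts, displs


counts, displs = partition(10000, 23)
print(counts)
print(displs)
print(counts.sum(), counts.max() - counts.min())
```

    For n = 10000 and 23 processes, the first 5 ranks receive 434 elements and the last 18 receive 435, so no rank holds more than one element beyond any other.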

    ================================================

    Unless otherwise noted, the code above is run with:

    mpiexec -np 8 python x.py

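    This blog is a record of the author's personal study notes, and the posts are not guaranteed to be original. Some posts cite the source they were reposted from, and some are compiled from several online sources; some material has inevitably been left unattributed. If anything infringes your rights, please contact the author.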
  • Original post: https://www.cnblogs.com/devilmaycry812839668/p/15142103.html