zoukankan      html  css  js  c++  java
  • Pytorch Distributed 初始化

    Pytorch Distributed 初始化方法

    参考文献

    https://pytorch.org/docs/master/distributed.html

    代码
    https://github.com/overfitover/pytorch-distributed
    欢迎来star me.

    初始化

    torch.distributed.init_process_group(backend, init_method='env://', **kwargs)
    

    参数说明

    • backend(str): 后端选择,包括 tcp mpi gloo
    • init_method(str, optional): 用来初始化包的URL, 用来做并发控制的共享方式
    • world_size(int, optional): 参与工作的进程数
    • rank(int, optional): 当前进程的rank
    • group_name(str, optional): 用来标记这组进程。

    init_method()

    有三种方法:

    • file:// 共享文件系统
    • tcp:// IP组播
    • env:// 环境变量 (默认是这个)

    env

    #!/usr/bin/env python
    import os
    import torch
    import torch.distributed as dist
    from torch.multiprocessing import Process
    import time
    
    def run(rank, size):
        pass
    
    
    def init_processes(rank, size, fn, backend='gloo'):
        """ Initialize the distributed environment. """
        os.environ['MASTER_ADDR'] = '162.128.0.22'
        os.environ['MASTER_PORT'] = '29555'
        dist.init_process_group(backend, rank=rank, world_size=size)
        torch.cuda.manual_seed(1)
        fn(rank, size)
        print("MM")
        print(dist.get_rank())
        print(dist.get_world_size())
        print(dist.is_available())
    
    
    def main():
    
        size = 2
        processes=[]
        for i in range(size):
            p = Process(target=init_processes, args=(i, size, run))
            p.start()
            processes.append(p)
    
        for p in processes:
            p.join()
    
    if __name__ == "__main__":
        start_time = time.time()
        main()
    
        end_time = time.time()
        print("耗时:", end_time-start_time)
    
    

    注意
    将162.128.0.22换成自己的IP地址。

    tcp

    import torch
    import torch.distributed as dist
    import argparse
    from time import sleep
    from random import randint
    from torch.multiprocessing import Process
    
    
    def initialize(rank, world_size, ip, port):
        dist.init_process_group(backend='tcp', init_method='tcp://{}:{}'.format(ip, port), rank=rank, world_size=world_size)
        print("MM")
    
    def main():
        parser = argparse.ArgumentParser()
        parser.add_argument('--ip', type=str, default='162.128.0.22')
        parser.add_argument('--port', type=str, default='20000')
        parser.add_argument('--rank', '-r', type=int)
        parser.add_argument('--world-size', '-s', type=int)
        args = parser.parse_args()
        print(args)
        # initialize(args.rank, args.world_size, args.ip, args.port)
    
        size = 2
        processes = []
        for i in range(size):
            p = Process(target=initialize, args=(i, size, args.ip, args.port))
            p.start()
            processes.append(p)
    
        for p in processes:
            p.join()
    
    
    if __name__ == '__main__':
        main()
    

    注意
    将162.128.0.22换成自己的IP地址。

    共享文件

    import argparse
    from time import sleep
    from random import randint
    from torch.multiprocessing import Process
    
    
    def initialize(rank, world_size):
        dist.init_process_group(backend='gloo', init_method='file:///home/yxk/Documents/Deeplearningoflidar139/overfitover/share', rank=rank, world_size=world_size)
        print("MM")
    
    def main():
    
        size = 2
        processes = []
        for i in range(size):
            p = Process(target=initialize, args=(i, size))
            p.start()
            processes.append(p)
    
        for p in processes:
            p.join()
    
    
    if __name__ == '__main__':
        main()
    

    注意
    init_method: 需要以file://开头,包含共享文件系统上不存在的文件(在现有目录中)的路径。如果文件不存在, 文件系统初始化将自动创建该文件,但不会删除该文件。你要在下一个init_process_group调用之前清楚该文件。

  • 相关阅读:
    SQL 脚本 重复执行 约束
    xiami 精选集
    PHP 5 环境配置
    Thread线程类
    创建线程
    C#中简单的正则表达式(也经常会用到的)
    线程的挂起与恢复
    C#操作INI文件
    多线程简介
    单线程简介
  • 原文地址:https://www.cnblogs.com/o-v-o/p/9975355.html
Copyright © 2011-2022 走看看