zoukankan      html  css  js  c++  java
  • Horovod-Usage

    Usage

    代码中要包含以下6步:

    1. 初始化
    Run hvd.init() to initialize Horovod.
    
    1. 将每个GPU固定到单个进程以避免资源争用。
      一个线程一个GPU,设置到 local rank ,第一个线程将分配给第一个GPU。第二个线程将分配给第二个GPU 向每个 TensorFlow 进程分配一个 GPU
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    
    1. 根据worker的数量,来确定学习率
    loss = ...
    opt = tf.train.AdagradOptimizer(0.01 * hvd.size())
    
    1. 使用 Horovod 优化器包裹每一个常规 TensorFlow 优化器,Horovod 优化器使用 ring-allreduce 平均梯度
    opt = hvd.DistributedOptimizer(opt)
    
    1. 将变量从第一个流程向其他流程传播,以实现一致性初始化. 从 rank 0 广播到所有的线程
    hooks = [hvd.BroadcastGlobalVariablesHook(0)]
    
    1. 将checkpoints 保存在worker0上
    with tf.train.MonitoredTrainingSession(checkpoint_dir=checkpoint_dir,
                                           config=config,
                                           hooks=hooks) as mon_sess:
    
    import tensorflow as tf
    import horovod.tensorflow as hvd
    
    
    # Initialize Horovod
    hvd.init()
    
    # Pin GPU to be used to process local rank (one GPU per process)
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    
    # Build model...
    loss = ...
    opt = tf.train.AdagradOptimizer(0.01 * hvd.size())
    
    # Add Horovod Distributed Optimizer
    opt = hvd.DistributedOptimizer(opt)
    
    # Add hook to broadcast variables from rank 0 to all other processes during
    # initialization.
    hooks = [hvd.BroadcastGlobalVariablesHook(0)]
    
    # Make training operation
    train_op = opt.minimize(loss)
    
    # Save checkpoints only on worker 0 to prevent other workers from corrupting them.
    checkpoint_dir = '/tmp/train_logs' if hvd.rank() == 0 else None
    
    # The MonitoredTrainingSession takes care of session initialization,
    # restoring from a checkpoint, saving to a checkpoint, and closing when done
    # or an error occurs.
    with tf.train.MonitoredTrainingSession(checkpoint_dir=checkpoint_dir,
                                           config=config,
                                           hooks=hooks) as mon_sess:
      while not mon_sess.should_stop():
        # Perform synchronous training.
        mon_sess.run(train_op)
    
  • 相关阅读:
    socket编程原理
    配置Symbian WINS Emulator
    mysql 的乱码解决方法
    深入剖析关于JSP和Servlet对中文的处理
    一个分众传媒业务员的销售日记
    中移动第四季度SP评级结果出炉 A级企业仅5家
    基于socket的聊天室实现原理
    看Linux内核源码 练内力必备技能
    Dell要收购AMD?
    同步执行其他程序(dos命令)
  • 原文地址:https://www.cnblogs.com/shix0909/p/13391003.html
Copyright © 2011-2022 走看看