  • Horovod-Usage

    Usage

    Your training code needs to include the following six steps:

    1. Initialize.
    Run hvd.init() to initialize Horovod.
    
    2. Pin each GPU to a single process to avoid resource contention.
      One GPU per process, selected by the local rank: the first process on a machine is assigned the first GPU, the second process the second GPU, and so on, so each TensorFlow process gets exactly one GPU. (For example, with two 4-GPU machines, hvd.rank() runs from 0 to 7 across all processes, while hvd.local_rank() runs from 0 to 3 on each machine.)
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    
    3. Scale the learning rate by the number of workers.
      With synchronous training, the effective batch size grows with the number of workers, so a common heuristic is to multiply the base learning rate by hvd.size(); e.g. a base rate of 0.01 becomes an effective 0.04 with 4 workers.
    loss = ...
    opt = tf.train.AdagradOptimizer(0.01 * hvd.size())
    
    4. Wrap each regular TensorFlow optimizer with the Horovod optimizer, which averages gradients across workers using ring-allreduce (see the toy sketch right after this list).
    opt = hvd.DistributedOptimizer(opt)
    
    5. Broadcast the initial variable states from the first process to all the others so that every worker starts from a consistent initialization: the broadcast goes from rank 0 to all other processes.
    hooks = [hvd.BroadcastGlobalVariablesHook(0)]
    
    6. Save checkpoints only on worker 0 so the other workers cannot corrupt them, then hand the config and hooks to a MonitoredTrainingSession:
    checkpoint_dir = '/tmp/train_logs' if hvd.rank() == 0 else None
    with tf.train.MonitoredTrainingSession(checkpoint_dir=checkpoint_dir,
                                           config=config,
                                           hooks=hooks) as mon_sess:
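
    Conceptually, the DistributedOptimizer in step 4 adds one thing to each training step: every worker computes gradients on its own data shard, and a ring-allreduce replaces each worker's gradients with their mean across all workers before the update is applied. A toy sketch of just that averaging arithmetic (plain Python with made-up numbers, not Horovod's actual implementation):

    # Four hypothetical workers each compute a local gradient
    # for the same parameter on their own data shard.
    local_grads = [0.2, 0.4, 0.6, 0.8]
    # Allreduce-averaging leaves every worker holding the same mean gradient,
    # so all model replicas apply identical updates and stay in sync.
    mean_grad = sum(local_grads) / len(local_grads)   # 0.5 on every worker

    Putting all six steps together: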
    
    import tensorflow as tf
    import horovod.tensorflow as hvd
    
    
    # Initialize Horovod
    hvd.init()
    
    # Pin GPU to be used to process local rank (one GPU per process)
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    
    # Build model...
    loss = ...
    opt = tf.train.AdagradOptimizer(0.01 * hvd.size())
    
    # Add Horovod Distributed Optimizer
    opt = hvd.DistributedOptimizer(opt)
    
    # Add hook to broadcast variables from rank 0 to all other processes during
    # initialization.
    hooks = [hvd.BroadcastGlobalVariablesHook(0)]
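    # Note: BroadcastGlobalVariablesHook is the MonitoredTrainingSession-friendly
    # form of the broadcast. As a sketch, if you manage a plain tf.Session
    # yourself, you can instead run the broadcast op once after initializing
    # variables:
    #   sess.run(tf.global_variables_initializer())
    #   sess.run(hvd.broadcast_global_variables(0))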
    
    # Make training operation
    train_op = opt.minimize(loss)
    
    # Save checkpoints only on worker 0 to prevent other workers from corrupting them.
    checkpoint_dir = '/tmp/train_logs' if hvd.rank() == 0 else None
    
    # The MonitoredTrainingSession takes care of session initialization,
    # restoring from a checkpoint, saving to a checkpoint, and closing when done
    # or an error occurs.
    with tf.train.MonitoredTrainingSession(checkpoint_dir=checkpoint_dir,
                                           config=config,
                                           hooks=hooks) as mon_sess:
      while not mon_sess.should_stop():
        # Perform synchronous training.
        mon_sess.run(train_op)
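
    To launch the script across processes, Horovod ships the horovodrun launcher (an MPI mpirun command works as well). Typical invocations, assuming the example above is saved as train.py (the filename and host names are placeholders):

    # 4 processes on the local machine, one per GPU
    horovodrun -np 4 python train.py

    # 8 processes across two 4-GPU machines
    horovodrun -np 8 -H server1:4,server2:4 python train.py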
    
  • Original post: https://www.cnblogs.com/shix0909/p/13391003.html