zoukankan      html  css  js  c++  java
  • tensorflow使用horovod进行多gpu训练

    tensorflow使用horovod多gpu训练

    要使用Horovod,在程序中添加以下内容。此示例使用TensorFlow。

    1. 运行hvd.init()

    1. 使用固定服务器GPU,以供此过程使用config.gpu_options.visible_device_list

      通过每个进程一个GPU的典型设置,您可以将其设置为local rank在这种情况下,服务器上的第一个进程将被分配第一GPU,第二个进程将被分配第二GPU,依此类推。

    1. 通过工人人数来衡量学习率

      同步分布式培训中的有效批处理规模是根据工人人数来衡量的。学习率的提高弥补了批量大小的增加。

    1. 将优化器包装在中hvd.DistributedOptimizer

      分布式优化器将梯度计算委派给原始优化器,使用allreduceallgather对梯度平均,然后应用这些平均梯度。

    1. 添加hvd.BroadcastGlobalVariablesHook(0)到播放初始变量状态从0级到所有其他进程

      当使用随机权重开始训练或从检查点恢复训练时,这是确保所有工人进行一致初始化的必要步骤。另外,如果您不使用MonitoredTrainingSession,则可以hvd.broadcast_global_variables在初始化全局变量之后执行op。

    1. 修改您的代码以仅在工作程序0上保存检查点,以防止其他工作程序破坏它们

      通过传递checkpoint_dir=Nonetf.train.MonitoredTrainingSessionif 完成此操作hvd.rank() != 0

    简单示例代码

     1 import tensorflow as tf
     2 import horovod.tensorflow as hvd
     3 
     4 
     5 # Initialize Horovod
     6 hvd.init()
     7 
     8 # Pin GPU to be used to process local rank (one GPU per process)
     9 config = tf.ConfigProto()
    10 config.gpu_options.visible_device_list = str(hvd.local_rank())
    11 
    12 # Build model...
    13 loss = ...
    14 opt = tf.train.AdagradOptimizer(0.01 * hvd.size())
    15 
    16 # Add Horovod Distributed Optimizer
    17 opt = hvd.DistributedOptimizer(opt)
    18 
    19 # Add hook to broadcast variables from rank 0 to all other processes during
    20 # initialization.
    21 hooks = [hvd.BroadcastGlobalVariablesHook(0)]
    22 
    23 # Make training operation
    24 train_op = opt.minimize(loss)
    25 
    26 # Save checkpoints only on worker 0 to prevent other workers from corrupting them.
    27 checkpoint_dir = '/tmp/train_logs' if hvd.rank() == 0 else None
    28 
    29 # The MonitoredTrainingSession takes care of session initialization,
    30 # restoring from a checkpoint, saving to a checkpoint, and closing when done
    31 # or an error occurs.
    32 with tf.train.MonitoredTrainingSession(checkpoint_dir=checkpoint_dir,
    33                                        config=config,
    34                                        hooks=hooks) as mon_sess:
    35   while not mon_sess.should_stop():
    36     # Perform synchronous training.
    37     mon_sess.run(train_op)
  • 相关阅读:
    读取组合单元格
    Spire.XLS:一款Excel处理神器
    linq
    LINQ语句中的.AsEnumerable() 和 .AsQueryable()的区别
    合并单元格
    web sec / ssd / sshd
    linux——cat之查看cpu信息、显示终端、校验内存.............
    MATLAB mcr lib的环境变量书写
    Linux查看库依赖方法
    判断当前所使用python的版本和来源
  • 原文地址:https://www.cnblogs.com/ywheunji/p/12298531.html
Copyright © 2011-2022 走看看