zoukankan      html  css  js  c++  java
  • 学习笔记TF061:分布式TensorFlow,分布式原理、最佳实践

    分布式TensorFlow由高性能gRPC库底层技术支持。Martin Abadi、Ashish Agarwal、Paul Barham论文《TensorFlow:Large-Scale Machine Learning on Heterogeneous Distributed Systems》。

    分布式原理。分布式集群 由多个服务器进程、客户端进程组成。部署方式,单机多卡、分布式(多机多卡)。多机多卡TensorFlow分布式。

    单机多卡,单台服务器多块GPU。训练过程:在单机单GPU训练,数据一个批次(batch)一个批次训练。单机多GPU,一次处理多个批次数据,每个GPU处理一个批次数据计算。变量参数保存在CPU,数据由CPU分发给多个GPU,GPU计算每个批次更新梯度。CPU收集完多个GPU更新梯度,计算平均梯度,更新参数。继续计算更新梯度。处理速度取决最慢GPU速度。

    分布式,训练在多个工作节点(worker)。工作节点,实现计算单元。计算服务器单卡,指服务器。计算服务器多卡,多个GPU划分多个工作节点。数据量大,超过一台机器处理能力,须用分布式。

    分布式TensorFlow底层通信,gRPC(google remote procedure call)。gRPC,谷歌开源高性能、跨语言RPC框架。RPC协议,远程过程调用协议,网络从远程计算机程度请求服务。

    分布式部署方式。分布式运行,多个计算单元(工作节点),后端服务器部署单工作节点、多工作节点。

    单工作节点部署。每台服务器运行一个工作节点,服务器多个GPU,一个工作节点可以访问多块GPU卡。代码tf.device()指定运行操作设备。优势,单机多GPU间通信,效率高。劣势,手动代码指定设备。

    多工作节点部署。一台服务器运行多个工作节点。

    设置CUDA_VISIBLE_DEVICES环境变量,限制各个工作节点只可见一个GPU,启动进程添加环境变量。用tf.device()指定特定GPU。多工作节点部署优势,代码简单,提高GPU使用率。劣势,工作节点通信,需部署多个工作节点。https://github.com/tobegit3hub/tensorflow_examples/tree/master/distributed_tensorflow 。

    CUDA_VISIBLE_DEVICES='' python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=ps --task_index=0
    CUDA_VISIBLE_DEVICES='' python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=ps --task_index=1
    CUDA_VISIBLE_DEVICES='0' python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=worker --task_index=0
    CUDA_VISIBLE_DEVICES='1' python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=worker --task_index=1

    分布式架构。https://www.tensorflow.org/extend/architecture 。客户端(client)、服务端(server),服务端包括主节点(master)、工作节点(worker)组成。

    客户端、主节点、工作节点关系。TensorFlow,客户端会话联系主节点,实际工作由工作节点实现,每个工作节点占一台设备(TensorFlow具体计算硬件抽象,CPU或GPU)。单机模式,客户端、主节点、工作节点在同一台服务器。分布模式,可不同服务器。客户端->主节点->工作节点/job:worker/task:0->/job:ps/task:0。
    客户端。建立TensorFlow计算图,建立与集群交互会话层。代码包含Session()。一个客户端可同时与多个服务端相连,一具服务端也可与多个客户端相连。
    服务端。运行tf.train.Server实例进程,TensroFlow执行任务集群(cluster)一部分。有主节点服务(Master service)和工作节点服务(Worker service)。运行中,一个主节点进程和数个工作节点进程,主节点进程和工作接点进程通过接口通信。单机多卡和分布式结构相同,只需要更改通信接口实现切换。
    主节点服务。实现tensorflow::Session接口。通过RPC服务程序连接工作节点,与工作节点服务进程工作任务通信。TensorFlow服务端,task_index为0作业(job)。
    工作节点服务。实现worker_service.proto接口,本地设备计算部分图。TensorFlow服务端,所有工作节点包含工作节点服务逻辑。每个工作节点负责管理一个或多个设备。工作节点可以是本地不同端口不同进程,或多台服务多个进程。运行TensorFlow分布式执行任务集,一个或多个作业(job)。每个作业,一个或多个相同目的任务(task)。每个任务,一个工作进程执行。作业是任务集合,集群是作业集合。
    分布式机器学习框架,作业分参数作业(parameter job)和工作节点作业(worker job)。参数作业运行服务器为参数服务器(parameter server,PS),管理参数存储、更新。工作节点作业,管理无状态主要从事计算任务。模型越大,参数越多,模型参数更新超过一台机器性能,需要把参数分开到不同机器存储更新。参数服务,多台机器组成集群,类似分布式存储架构,涉及数据同步、一致性,参数存储为键值对(key-value)。分布式键值内存数据库,加参数更新操作。李沐《Parameter Server for Distributed Machine Learning》http://www.cs.cmu.edu/~muli/file/ps.pdf 。
    参数存储更新在参数作业进行,模型计算在工作节点作业进行。TensorFlow分布式实现作业间数据传输,参数作业到工作节点作业前向传播,工作节点作业到参数作业反向传播。
    任务。特定TensorFlow服务器独立进程,在作业中拥有对应序号。一个任务对应一个工作节点。集群->作业->任务->工作节点。

    客户端、主节点、工作节点交互过程。单机多卡交互,客户端->会话运行->主节点->执行子图->工作节点->GPU0、GPU1。分布式交互,客户端->会话运行->主节点进程->执行子图1->工作节点进程1->GPU0、GPU1。《TensorFlow:Large-Scale Machine Learning on Heterogeneous distributed Systems》https://arxiv.org/abs/1603.04467v1 。

    分布式模式。

    数据并行。https://www.tensorflow.org/tutorials/deep_cnn 。CPU负责梯度平均、参数更新,不同GPU训练模型副本(model replica)。基于训练样例子集训练,模型有独立性。
    步骤:不同GPU分别定义模型网络结构。单个GPU从数据管道读取不同数据块,前向传播,计算损失,计算当前变量梯度。所有GPU输出梯度数据转移到CPU,梯度求平均操作,模型变量更新。重复,直到模型变量收敛。
    数据并行,提高SGD效率。SGD mini-batch样本,切成多份,模型复制多份,在多个模型上同时计算。多个模型计算速度不一致,CPU更新变量有同步、异步两个方案。

    同步更新、异步更新。分布式随机梯度下降法,模型参数分布式存储在不同参数服务上,工作节点并行训练数据,和参数服务器通信获取模型参数。
    同步随机梯度下降法(Sync-SGD,同步更新、同步训练),训练时,每个节点上工作任务读入共享参数,执行并行梯度计算,同步需要等待所有工作节点把局部梯度处好,将所有共享参数合并、累加,再一次性更新到模型参数,下一批次,所有工作节点用模型更新后参数训练。优势,每个训练批次考虑所有工作节点训练情部,损失下降稳定。劣势,性能瓶颈在最慢工作节点。异楹设备,工作节点性能不同,劣势明显。
    异步随机梯度下降法(Async-SGD,异步更新、异步训练),每个工作节点任务独立计算局部梯度,异步更新到模型参数,不需执行协调、等待操作。优势,性能不存在瓶颈。劣势,每个工作节点计算梯度值发磅回参数服务器有参数更新冲突,影响算法收剑速度,损失下降过程抖动较大。
    同步更新、异步更新实现区别于更新参数服务器参数策略。数据量小,各节点计算能力较均衡,用同步模型。数据量大,各机器计算性能参差不齐,用异步模式。
    带备份的Sync-SGD(Sync-SDG with backup)。Jianmin Chen、Xinghao Pan、Rajat Monga、Aamy Bengio、Rafal Jozefowicz论文《Revisiting Distributed Synchronous SGD》https://arxiv.org/abs/1604.00981 。增加工作节点,解决部分工作节点计算慢问题。工作节点总数n+n*5%,n为集群工作节点数。异步更新设定接受到n个工作节点参数直接更新参数服务器模型参数,进入下一批次模型训练。计算较慢节点训练参数直接丢弃。
    同步更新、异步更新有图内模式(in-graph pattern)和图间模式(between-graph pattern),独立于图内(in-graph)、图间(between-graph)概念。
    图内复制(in-grasph replication),所有操作(operation)在同一个图中,用一个客户端来生成图,把所有操作分配到集群所有参数服务器和工作节点上。国内复制和单机多卡类似,扩展到多机多卡,数据分发还是在客户端一个节点上。优势,计算节点只需要调用join()函数等待任务,客户端随时提交数据就可以训练。劣势,训练数据分发在一个节点上,要分发给不同工作节点,严重影响并发训练速度。
    图间复制(between-graph replication),每一个工作节点创建一个图,训练参数保存在参数服务器,数据不分发,各个工作节点独立计算,计算完成把要更新参数告诉参数服务器,参数服务器更新参数。优势,不需要数据分发,各个工作节点都创建图和读取数据训练。劣势,工作节点既是图创建者又是计算任务执行者,某个工作节点宕机影响集群工作。大数据相关深度学习推荐使用图间模式。

    模型并行。切分模型,模型不同部分执行在不同设备上,一个批次样本可以在不同设备同时执行。TensorFlow尽量让相邻计算在同一台设备上完成节省网络开销。Martin Abadi、Ashish Agarwal、Paul Barham论文《TensorFlow:Large-Scale Machine Learning on Heterogeneous Distributed Systems》https://arxiv.org/abs/1603.04467v1 。

    模型并行、数据并行,TensorFlow中,计算可以分离,参数可以分离。可以在每个设备上分配计算节点,让对应参数也在该设备上,计算参数放一起。

    分布式API。https://www.tensorflow.org/deploy/distributed 。
    创建集群,每个任务(task)启动一个服务(工作节点服务或主节点服务)。任务可以分布不同机器,可以同一台机器启动多个任务,用不同GPU运行。每个任务完成工作:创建一个tf.train.ClusterSpec,对集群所有任务进行描述,描述内容对所有任务相同。创建一个tf.train.Server,创建一个服务,运行相应作业计算任务。
    TensorFlow分布式开发API。tf.train.ClusterSpec({"ps":ps_hosts,"worker":worke_hosts})。创建TensorFlow集群描述信息,ps、worker为作业名称,ps_phsts、worker_hosts为作业任务所在节点地址信息。tf.train.ClusterSpec传入参数,作业和任务间关系映射,映射关系任务通过IP地址、端口号表示。

    结构 tf.train.ClusterSpec({"local":["localhost:2222","localhost:2223"]})
    可用任务 /job:local/task:0、/job:local/task:1。
    结构 tf.train.ClusterSpec({"worker":["worker0.example.com:2222","worker1.example.com:2222","worker2.example.com:2222"],"ps":["ps0.example.com:2222","ps1.example.com:2222"]})
    可用任务 /job:worker/task:0、 /job:worker/task:1、 /job:worker/task:2、 /job:ps/task:0、 /job:ps/task:1
    tf.train.Server(cluster,job_name,task_index)。创建服务(主节点服务或工作节点服务),运行作业计算任务,运行任务在task_index指定机器启动。

    #任务0
    cluster = tr.train.ClusterSpec({"local":["localhost:2222","localhost:2223"]})
    server = tr.train.Server(cluster,job_name="local",task_index=0)
    #任务1
    cluster = tr.train.ClusterSpec({"local":["localhost:2222","localhost:2223"]})
    server = tr.train.Server(cluster,job_name="local",task_index=1)。
    自动化管理节点、监控节点工具。集群管理工具Kubernetes。
    tf.device(device_name_or_function)。设定指定设备执行张量运算,批定代码运行CPU、GPU。

    #指定在task0所在机器执行Tensor操作运算
    with tf.device("/job:ps/task:0"):
    weights_1 = tf.Variable(…)
    biases_1 = tf.Variable(…)

    分布式训练代码框架。创建TensorFlow服务器集群,在该集群分布式计算数据流图。https://github.com/tensorflow/tensorflow/blob/master/tensorflow/docs_src/deploy/distributed.md 。

    import argparse
    import sys
    import tensorflow as tf
    FLAGS = None
    def main(_):
      # 第1步:命令行参数解析,获取集群信息ps_hosts、worker_hosts
      # 当前节点角色信息job_name、task_index
      ps_hosts = FLAGS.ps_hosts.split(",")
      worker_hosts = FLAGS.worker_hosts.split(",")
      # 第2步:创建当前任务节点服务器
      # Create a cluster from the parameter server and worker hosts.
      cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
      # Create and start a server for the local task.
      server = tf.train.Server(cluster,
                               job_name=FLAGS.job_name,
                               task_index=FLAGS.task_index)
      # 第3步:如果当前节点是参数服务器,调用server.join()无休止等待;如果是工作节点,执行第4步
      if FLAGS.job_name == "ps":
        server.join()
      # 第4步:构建要训练模型,构建计算图
      elif FLAGS.job_name == "worker":
        # Assigns ops to the local worker by default.
        with tf.device(tf.train.replica_device_setter(
            worker_device="/job:worker/task:%d" % FLAGS.task_index,
            cluster=cluster)):
          # Build model...
          loss = ...
          global_step = tf.contrib.framework.get_or_create_global_step()
          train_op = tf.train.AdagradOptimizer(0.01).minimize(
              loss, global_step=global_step)
        # The StopAtStepHook handles stopping after running given steps.
        # 第5步管理模型训练过程
        hooks=[tf.train.StopAtStepHook(last_step=1000000)]
        # The MonitoredTrainingSession takes care of session initialization,
        # restoring from a checkpoint, saving to a checkpoint, and closing when done
        # or an error occurs.
        with tf.train.MonitoredTrainingSession(master=server.target,
                                               is_chief=(FLAGS.task_index == 0),
                                               checkpoint_dir="/tmp/train_logs",
                                               hooks=hooks) as mon_sess:
          while not mon_sess.should_stop():
            # Run a training step asynchronously.
            # See `tf.train.SyncReplicasOptimizer` for additional details on how to
            # perform *synchronous* training.
            # mon_sess.run handles AbortedError in case of preempted PS.
            # 训练模型
            mon_sess.run(train_op)
    if __name__ == "__main__":
      parser = argparse.ArgumentParser()
      parser.register("type", "bool", lambda v: v.lower() == "true")
      # Flags for defining the tf.train.ClusterSpec
      parser.add_argument(
          "--ps_hosts",
          type=str,
          default="",
          help="Comma-separated list of hostname:port pairs"
      )
      parser.add_argument(
          "--worker_hosts",
          type=str,
          default="",
          help="Comma-separated list of hostname:port pairs"
      )
      parser.add_argument(
          "--job_name",
          type=str,
          default="",
          help="One of 'ps', 'worker'"
      )
      # Flags for defining the tf.train.Server
      parser.add_argument(
          "--task_index",
          type=int,
          default=0,
          help="Index of task within the job"
      )
      FLAGS, unparsed = parser.parse_known_args()
      tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

    分布式最佳实践。https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/dist_test/python/mnist_replica.py 。

    MNIST数据集分布式训练。开设3个端口作分布式工作节点部署,2222端口参数服务器,2223端口工作节点0,2224端口工作节点1。参数服务器执行参数更新任务,工作节点0、工作节点1执行图模型训练计算任务。参数服务器/job:ps/task:0 cocalhost:2222,工作节点/job:worker/task:0 cocalhost:2223,工作节点/job:worker/task:1 cocalhost:2224。
    运行代码。

    python mnist_replica.py --job_name="ps" --task_index=0
    python mnist_replica.py --job_name="worker" --task_index=0
    python mnist_replica.py --job_name="worker" --task_index=1
    
    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    import math
    import sys
    import tempfile
    import time
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    # 定义常量,用于创建数据流图
    flags = tf.app.flags
    flags.DEFINE_string("data_dir", "/tmp/mnist-data",
                        "Directory for storing mnist data")
    # 只下载数据,不做其他操作
    flags.DEFINE_boolean("download_only", False,
                         "Only perform downloading of data; Do not proceed to "
                         "session preparation, model definition or training")
    # task_index从0开始。0代表用来初始化变量的第一个任务
    flags.DEFINE_integer("task_index", None,
                         "Worker task index, should be >= 0. task_index=0 is "
                         "the master worker task the performs the variable "
                         "initialization ")
    # 每台机器GPU个数,机器没有GPU为0
    flags.DEFINE_integer("num_gpus", 1,
                         "Total number of gpus for each machine."
                         "If you don't use GPU, please set it to '0'")
    # 同步训练模型下,设置收集工作节点数量。默认工作节点总数
    flags.DEFINE_integer("replicas_to_aggregate", None,
                         "Number of replicas to aggregate before parameter update"
                         "is applied (For sync_replicas mode only; default: "
                         "num_workers)")
    flags.DEFINE_integer("hidden_units", 100,
                         "Number of units in the hidden layer of the NN")
    # 训练次数
    flags.DEFINE_integer("train_steps", 200,
                         "Number of (global) training steps to perform")
    flags.DEFINE_integer("batch_size", 100, "Training batch size")
    flags.DEFINE_float("learning_rate", 0.01, "Learning rate")
    # 使用同步训练、异步训练
    flags.DEFINE_boolean("sync_replicas", False,
                         "Use the sync_replicas (synchronized replicas) mode, "
                         "wherein the parameter updates from workers are aggregated "
                         "before applied to avoid stale gradients")
    # 如果服务器已经存在,采用gRPC协议通信;如果不存在,采用进程间通信
    flags.DEFINE_boolean(
        "existing_servers", False, "Whether servers already exists. If True, "
        "will use the worker hosts via their GRPC URLs (one client process "
        "per worker host). Otherwise, will create an in-process TensorFlow "
        "server.")
    # 参数服务器主机
    flags.DEFINE_string("ps_hosts","localhost:2222",
                        "Comma-separated list of hostname:port pairs")
    # 工作节点主机
    flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224",
                        "Comma-separated list of hostname:port pairs")
    # 本作业是工作节点还是参数服务器
    flags.DEFINE_string("job_name", None,"job name: worker or ps")
    FLAGS = flags.FLAGS
    IMAGE_PIXELS = 28
    def main(unused_argv):
      mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
      if FLAGS.download_only:
        sys.exit(0)
      if FLAGS.job_name is None or FLAGS.job_name == "":
        raise ValueError("Must specify an explicit `job_name`")
      if FLAGS.task_index is None or FLAGS.task_index =="":
        raise ValueError("Must specify an explicit `task_index`")
      print("job name = %s" % FLAGS.job_name)
      print("task index = %d" % FLAGS.task_index)
      #Construct the cluster and start the server
      # 读取集群描述信息
      ps_spec = FLAGS.ps_hosts.split(",")
      worker_spec = FLAGS.worker_hosts.split(",")
      # Get the number of workers.
      num_workers = len(worker_spec)
      # 创建TensorFlow集群描述对象
      cluster = tf.train.ClusterSpec({
          "ps": ps_spec,
          "worker": worker_spec})
      # 为本地执行任务创建TensorFlow Server对象。
      if not FLAGS.existing_servers:
        # Not using existing servers. Create an in-process server.
        # 创建本地Sever对象,从tf.train.Server这个定义开始,每个节点开始不同
        # 根据执行的命令的参数(作业名字)不同,决定这个任务是哪个任务
        # 如果作业名字是ps,进程就加入这里,作为参数更新的服务,等待其他工作节点给它提交参数更新的数据
        # 如果作业名字是worker,就执行后面的计算任务
        server = tf.train.Server(
            cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
        # 如果是参数服务器,直接启动即可。这里,进程就会阻塞在这里
        # 下面的tf.train.replica_device_setter代码会将参数批定给ps_server保管
        if FLAGS.job_name == "ps":
          server.join()
      # 处理工作节点
      # 找出worker的主节点,即task_index为0的点
      is_chief = (FLAGS.task_index == 0)
      # 如果使用gpu
      if FLAGS.num_gpus > 0:
        # Avoid gpu allocation conflict: now allocate task_num -> #gpu
        # for each worker in the corresponding machine
        gpu = (FLAGS.task_index % FLAGS.num_gpus)
        # 分配worker到指定gpu上运行
        worker_device = "/job:worker/task:%d/gpu:%d" % (FLAGS.task_index, gpu)
      # 如果使用cpu
      elif FLAGS.num_gpus == 0:
        # Just allocate the CPU to worker server
        # 把cpu分配给worker
        cpu = 0
        worker_device = "/job:worker/task:%d/cpu:%d" % (FLAGS.task_index, cpu)
      # The device setter will automatically place Variables ops on separate
      # parameter servers (ps). The non-Variable ops will be placed on the workers.
      # The ps use CPU and workers use corresponding GPU
      # 用tf.train.replica_device_setter将涉及变量操作分配到参数服务器上,使用CPU。将涉及非变量操作分配到工作节点上,使用上一步worker_device值。
      # 在这个with语句之下定义的参数,会自动分配到参数服务器上去定义。如果有多个参数服务器,就轮流循环分配
      with tf.device(
          tf.train.replica_device_setter(
              worker_device=worker_device,
              ps_device="/job:ps/cpu:0",
              cluster=cluster)):
    
        # 定义全局步长,默认值为0
        global_step = tf.Variable(0, name="global_step", trainable=False)
        # Variables of the hidden layer
        # 定义隐藏层参数变量,这里是全连接神经网络隐藏层
        hid_w = tf.Variable(
            tf.truncated_normal(
                [IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
                stddev=1.0 / IMAGE_PIXELS),
            name="hid_w")
        hid_b = tf.Variable(tf.zeros([FLAGS.hidden_units]), name="hid_b")
        # Variables of the softmax layer
        # 定义Softmax 回归层参数变量
        sm_w = tf.Variable(
            tf.truncated_normal(
                [FLAGS.hidden_units, 10],
                stddev=1.0 / math.sqrt(FLAGS.hidden_units)),
            name="sm_w")
        sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
        # Ops: located on the worker specified with FLAGS.task_index
        # 定义模型输入数据变量
        x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS])
        y_ = tf.placeholder(tf.float32, [None, 10])
        # 构建隐藏层
        hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
        hid = tf.nn.relu(hid_lin)
        # 构建损失函数和优化器
        y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
        cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
        # 异步训练模式:自己计算完成梯度就去更新参数,不同副本之间不会去协调进度
        opt = tf.train.AdamOptimizer(FLAGS.learning_rate)
        # 同步训练模式
        if FLAGS.sync_replicas:
          if FLAGS.replicas_to_aggregate is None:
            replicas_to_aggregate = num_workers
          else:
            replicas_to_aggregate = FLAGS.replicas_to_aggregate
          # 使用SyncReplicasOptimizer作优化器,并且是在图间复制情况下
          # 在图内复制情况下将所有梯度平均
          opt = tf.train.SyncReplicasOptimizer(
              opt,
              replicas_to_aggregate=replicas_to_aggregate,
              total_num_replicas=num_workers,
              name="mnist_sync_replicas")
        train_step = opt.minimize(cross_entropy, global_step=global_step)
        if FLAGS.sync_replicas:
          local_init_op = opt.local_step_init_op
          if is_chief:
            # 所有进行计算工作节点里一个主工作节点(chief)
            # 主节点负责初始化参数、模型保存、概要保存
            local_init_op = opt.chief_init_op
          ready_for_local_init_op = opt.ready_for_local_init_op
          # Initial token and chief queue runners required by the sync_replicas mode
          # 同步训练模式所需初始令牌、主队列
          chief_queue_runner = opt.get_chief_queue_runner()
          sync_init_op = opt.get_init_tokens_op()
        init_op = tf.global_variables_initializer()
        train_dir = tempfile.mkdtemp()
        if FLAGS.sync_replicas:
          # 创建一个监管程序,用于统计训练模型过程中的信息
          # lodger 是保存和加载模型路径
          # 启动就会去这个logdir目录看是否有检查点文件,有的话就自动加载
          # 没有就用init_op指定初始化参数
          # 主工作节点(chief)负责模型参数初始化工作
          # 过程中,其他工作节点等待主节眯完成初始化工作,初始化完成后,一起开始训练数据
          # global_step值是所有计算节点共享的
          # 在执行损失函数最小值时自动加1,通过global_step知道所有计算节点一共计算多少步
          sv = tf.train.Supervisor(
              is_chief=is_chief,
              logdir=train_dir,
              init_op=init_op,
              local_init_op=local_init_op,
              ready_for_local_init_op=ready_for_local_init_op,
              recovery_wait_secs=1,
              global_step=global_step)
        else:
          sv = tf.train.Supervisor(
              is_chief=is_chief,
              logdir=train_dir,
              init_op=init_op,
              recovery_wait_secs=1,
              global_step=global_step)
        # 创建会话,设置属性allow_soft_placement为True
        # 所有操作默认使用被指定设置,如GPU
        # 如果该操作函数没有GPU实现,自动使用CPU设备
        sess_config = tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=False,
            device_filters=["/job:ps", "/job:worker/task:%d" % FLAGS.task_index])
        # The chief worker (task_index==0) session will prepare the session,
        # while the remaining workers will wait for the preparation to complete.
        # 主工作节点(chief),task_index为0节点初始化会话
        # 其余工作节点等待会话被初始化后进行计算
        if is_chief:
          print("Worker %d: Initializing session..." % FLAGS.task_index)
        else:
          print("Worker %d: Waiting for session to be initialized..." %
                FLAGS.task_index)
        if FLAGS.existing_servers:
          server_grpc_url = "grpc://" + worker_spec[FLAGS.task_index]
          print("Using existing server at: %s" % server_grpc_url)
          # 创建TensorFlow会话对象,用于执行TensorFlow图计算
          # prepare_or_wait_for_session需要参数初始化完成且主节点准备好后,才开始训练
          sess = sv.prepare_or_wait_for_session(server_grpc_url,
                                                config=sess_config)
        else:
          sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)
        print("Worker %d: Session initialization complete." % FLAGS.task_index)
        if FLAGS.sync_replicas and is_chief:
          # Chief worker will start the chief queue runner and call the init op.
          sess.run(sync_init_op)
          sv.start_queue_runners(sess, [chief_queue_runner])
        # Perform training
        # 执行分布式模型训练
        time_begin = time.time()
        print("Training begins @ %f" % time_begin)
        local_step = 0
        while True:
          # Training feed
          # 读入MNIST训练数据,默认每批次100张图片
          batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)
          train_feed = {x: batch_xs, y_: batch_ys}
          _, step = sess.run([train_step, global_step], feed_dict=train_feed)
          local_step += 1
          now = time.time()
          print("%f: Worker %d: training step %d done (global step: %d)" %
                (now, FLAGS.task_index, local_step, step))
          if step >= FLAGS.train_steps:
            break
        time_end = time.time()
        print("Training ends @ %f" % time_end)
        training_time = time_end - time_begin
        print("Training elapsed time: %f s" % training_time)
        # Validation feed
        # 读入MNIST验证数据,计算验证的交叉熵
        val_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
        val_xent = sess.run(cross_entropy, feed_dict=val_feed)
        print("After %d training step(s), validation cross entropy = %g" %
              (FLAGS.train_steps, val_xent))
    if __name__ == "__main__":
      tf.app.run()

    参考资料:

    《TensorFlow技术解析与实战》

    欢迎推荐上海机器学习工作机会,我的微信:qingxingfengzi

  • 相关阅读:
    Spring Boot (20) 拦截器
    Spring Boot (19) servlet、filter、listener
    Spring Boot (18) @Async异步
    Spring Boot (17) 发送邮件
    Spring Boot (16) logback和access日志
    Spring Boot (15) pom.xml设置
    Spring Boot (14) 数据源配置原理
    Spring Boot (13) druid监控
    Spring boot (12) tomcat jdbc连接池
    Spring Boot (11) mybatis 关联映射
  • 原文地址:https://www.cnblogs.com/libinggen/p/7814083.html
Copyright © 2011-2022 走看看