  • TensorFlow Distribution (data reading and training in distributed mode)

    Purpose of this article

    When it comes to distributed Estimator training, the official documentation has drifted out of sync with the API because of version updates. Concretely: when a Dataset is used as the input for distributed Estimator training, in version 1.12 training consumes exactly one pass over the dataset, i.e. all devices together iterate the data once.

    In version 2.0, by contrast, the amount of training data is the dataset multiplied by the number of devices; in other words, every device runs through the complete dataset on its own.
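
    For context, here is a minimal sketch of the setup being discussed, assuming the TF 1.12 tf.contrib.distribute API; the input_fn and model_fn below are toy illustrations, not code from the original post or the Estimator source.

    import tensorflow as tf

    def input_fn():
      # Toy dataset: 8 scalar examples in batches of 4.
      dataset = tf.data.Dataset.from_tensor_slices(
          {"x": tf.range(8, dtype=tf.float32)})
      return dataset.batch(4)

    def model_fn(features, labels, mode, config):
      # Trivial model: fit a single scalar w to the inputs.
      w = tf.get_variable("w", shape=[], initializer=tf.zeros_initializer())
      loss = tf.reduce_mean(tf.square(features["x"] - w))
      train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
          loss, global_step=tf.train.get_or_create_global_step())
      return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

    # Per the 1.12 behavior described above, the devices share one pass over
    # the dataset (each device receives different batches); per the 2.0
    # behavior, every replica would iterate the full dataset on its own.
    strategy = tf.contrib.distribute.MirroredStrategy()
    run_config = tf.estimator.RunConfig(train_distribute=strategy)
    estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)
    estimator.train(input_fn, steps=2)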

    Data reading in version 1.12

    1. Creating the graph in the main thread

    In the code below, the input function is called in the client to obtain an iterator. This code sits on the call path of the Estimator's distributed train.

    with ops.Graph().as_default() as g:
          # We want to create the iterations variable outside the distribution scope
          # as that is just stored on the host and mainly used to drive the loop
          # and doesn't need to be a Mirrored/Device variable.
          if is_tpu_strategy:
            steps_per_run_variable = training.get_or_create_steps_per_run_variable()
          with self._train_distribution.scope():
            random_seed.set_random_seed(self._config.tf_random_seed)
            iterator, input_hooks = self._get_iterator_from_input_fn(
                input_fn, model_fn_lib.ModeKeys.TRAIN, self._train_distribution)
    
    • _get_iterator_from_input_fn: this function builds the iterator from which the subsequent training reads its data.

      def _get_iterator_from_input_fn(self, input_fn, mode, distribution=None):
        if distribution is not None:
          result = distribution.distribute_dataset(
              lambda: self._call_input_fn(input_fn, mode))
        else:
          result = self._call_input_fn(input_fn, mode)
    
        iterator = result.make_initializable_iterator()
        input_hooks = [estimator_util._DatasetInitializerHook(iterator)]  # pylint: disable=protected-access
        return iterator, input_hooks
    

    Here distribute_dataset is called to produce the dataset. Stepping into it, you can see that it creates a PerDeviceDataset like the following:

    class PerDeviceDataset(object):
      """Like `tf.data.Dataset` split devices, producing `PerDevice` data."""
    
      def __init__(self, dataset, devices, prefetch_on_device=None):
        self._devices = devices
    
        # Default to using prefetching in graph mode, unless specified.
        # TODO(priyag): Enable prefetching in eager mode.
        self._prefetch_on_device = prefetch_on_device
        if self._prefetch_on_device is None:
          self._prefetch_on_device = not context.executing_eagerly()
        assert not (self._prefetch_on_device and context.executing_eagerly()), (
            "Prefetching is only supported in graph mode currently")
    
        if self._prefetch_on_device:
          self._dataset = dataset.apply(
              prefetching_ops_v2.prefetch_to_devices(self._devices))
        else:
          # TODO(priyag): If dropping remainder is not appropriate, find another
          # approach to distributing the dataset when not possible to divide evenly.
          # Possibly not an issue when we start using PartitionedDataset.
          self._dataset = dataset.batch(len(devices), drop_remainder=True)
    

    As the last line shows, an extra layer of batching is wrapped around the original dataset, splitting the data by the number of devices. The iterator created afterwards is likewise wrapped as a PerDeviceDataIterator, which forms a dictionary mapping: each device gets different data, selected by its index into that extra batch dimension, as the toy sketch below illustrates.
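
    The following is an illustrative sketch of that splitting, not the PerDeviceDataset code itself: an already batched dataset is re-batched by the number of devices, and element i of each "batch of batches" goes to device i. The device list and batch size are made up for the example.

    import tensorflow as tf

    devices = ["/device:GPU:0", "/device:GPU:1"]  # assumed device list
    per_device_batch = 4

    dataset = tf.data.Dataset.range(16).batch(per_device_batch)
    # The extra batch over the device dimension: each element now holds
    # len(devices) per-device batches, as in PerDeviceDataset above.
    dataset = dataset.batch(len(devices), drop_remainder=True)

    iterator = dataset.make_one_shot_iterator()
    batch_of_batches = iterator.get_next()  # shape: [num_devices, per_device_batch]

    # Device -> data mapping, analogous to what PerDeviceDataIterator builds.
    per_device_data = {d: batch_of_batches[i] for i, d in enumerate(devices)}

    with tf.Session() as sess:
      print(sess.run(per_device_data))
      # {'/device:GPU:0': array([0, 1, 2, 3]), '/device:GPU:1': array([4, 5, 6, 7])}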

    Distributed training

    Training in version 1.12 is fairly straightforward. For MirroredStrategy, one thread is created per device. One drawback is that these threads are created on every run; judging from the TODO in the source, this is expected to be optimized away later.

    Below is the client code that fetches data from the iterator and passes it to each device for computation via self._train_distribution.call_for_each_tower.

    features, labels = estimator_util.parse_iterator_result(
        iterator.get_next())
    grouped_estimator_spec = self._train_distribution.call_for_each_tower(
        self._call_model_fn,
        features,
        labels,  # although this will be None it seems
        model_fn_lib.ModeKeys.TRAIN,
        self.config)
    loss = self._train_distribution.unwrap(
        self._train_distribution.reduce(
            distribute_lib.get_loss_reduction(),
            grouped_estimator_spec.loss,
            destinations='/device:CPU:0'))[0]
    distributed_train_op = grouped_estimator_spec.train_op
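
    The reduce plus unwrap pair above collapses the per-tower losses into a single loss tensor on the destination device. The snippet below is only a conceptual sketch of that step in plain TF, with made-up device names and loss values; it is not the DistributionStrategy internals.

    import tensorflow as tf

    # Per-tower losses, one per device (values made up for illustration).
    per_tower_losses = {
        "/device:GPU:0": tf.constant(0.30),
        "/device:GPU:1": tf.constant(0.50),
    }

    with tf.device("/device:CPU:0"):
      # MEAN reduction: average the tower losses on the destination device.
      reduced_loss = tf.add_n(list(per_tower_losses.values())) / len(per_tower_losses)

    with tf.Session() as sess:
      print(sess.run(reduced_loss))  # 0.4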
    

    call_for_each_tower is the interface through which training runs on each device:

    def _call_for_each_tower(distribution, fn, *args, **kwargs):
      """Run `fn` in separate threads, once per tower/worker device.
      run_concurrently = kwargs.pop("run_concurrently", True)
      if not context.executing_eagerly():
        # Lots of TF library code isn't thread-safe in graph mode, and
        # there is little to be gained by turning on multithreading when
        # constructing a graph.
        run_concurrently = False
        # Needed for per-thread device, etc. contexts in graph mode.
        ops.get_default_graph().switch_to_thread_local()
      elif run_concurrently is None:
        run_concurrently = True
    
      coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,))
    
      shared_variable_store = {}
    
      # TODO(isaprykin): Create these threads once instead of during every run()
      # call.
      threads = []
      for index, d in enumerate(distribution.worker_devices):
        variable_creator_fn = shared_variable_creator.make_fn(
            shared_variable_store, index)
        t = MirroredStrategy._MirroredTowerThread(  # pylint: disable=protected-access
            distribution, coord, d, variable_creator_fn, fn,
            *values.select_device(d, args), **values.select_device(d, kwargs))
        threads.append(t)
    
      for t in threads:
        t.start()
    

    Here, select_device simply picks the value keyed by the corresponding device from the per-device container, which completes the distributed training flow. A rough sketch of the idea follows.
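
    The sketch below is only an illustration of the idea behind values.select_device, not the actual implementation in values.py: arguments that are per-device containers (device -> value mappings) yield the entry for the current device, while plain values pass through unchanged.

    # Hypothetical stand-in: a plain dict plays the role of a PerDevice container.
    def select_device(device, value):
      if isinstance(value, dict):
        return value[device]
      return value

    per_device_features = {
        "/device:GPU:0": "batch for GPU 0",
        "/device:GPU:1": "batch for GPU 1",
    }

    print(select_device("/device:GPU:1", per_device_features))  # batch for GPU 1
    print(select_device("/device:GPU:1", 42))                   # 42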
