  • PyTorch DataLoader: Miscellaneous Notes

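     Input data pipeline

    In PyTorch, data reaches the model in the following order:

    ① Create a Dataset object
    ② Create a DataLoader object
    ③ Loop over the DataLoader object and feed each img, label pair to the model for training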

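    from torch.utils.data import DataLoader

    dataset = MyDataset()              # MyDataset: your own Dataset subclass
    dataloader = DataLoader(dataset)
    num_epoches = 100
    for epoch in range(num_epoches):
        for img, label in dataloader:
            ...                        # training step goes here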
    

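    So, as the step that feeds data directly into the model, DataLoader is very important.

    First, a brief introduction. DataLoader is a key data-loading interface in PyTorch, defined in dataloader.py. Almost any model trained with PyTorch uses it (unless the user writes a replacement…). Its purpose is to take a user-defined Dataset and, according to the batch size, whether to shuffle, and so on, wrap the samples into Tensors of batch-size length for the training that follows.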
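To make the "user-defined Dataset" above concrete, here is a minimal sketch. `ToyDataset` and its synthetic samples are hypothetical. A real dataset would normally subclass `torch.utils.data.Dataset`, but for an index-based dataset the loader source shown later only relies on `dataset[i]` and `len(dataset)`, so plain Python is enough to show the shape of the contract.

```python
class ToyDataset:
    """Hypothetical index-based dataset: index -> (img, label)."""

    def __init__(self, num_samples):
        # Fake "images" are short lists of floats; labels alternate 0/1.
        self.samples = [([float(i)] * 4, i % 2) for i in range(num_samples)]

    def __getitem__(self, index):
        # Called once per index when a batch is assembled.
        return self.samples[index]

    def __len__(self):
        # Used by samplers to know how many indices exist.
        return len(self.samples)


dataset = ToyDataset(5)
img, label = dataset[3]
```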

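    The official description of DataLoader is:

    "Data loader. Combines a dataset and a sampler, and provides single- or multi-process iterators over the dataset."

    For the concepts of iterator and iterable and the difference between them, look them up if interested. In terms of implementation, an iterator has both __iter__ and __next__ methods, while an iterable only needs __iter__.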
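The distinction can be checked directly with built-in types. This is only an illustration using a plain list, not PyTorch code: the list is an iterable, and `iter()` asks it for an iterator.

```python
numbers = [1, 2, 3]          # a list is an iterable
it = iter(numbers)           # iter() asks the iterable for an iterator

list_has_next = hasattr(numbers, "__next__")   # iterables need not define __next__
iter_has_next = hasattr(it, "__next__")        # iterators do
iter_is_own_iterator = iter(it) is it          # an iterator's __iter__ returns itself

first_pass = list(it)        # consumes the iterator
second_pass = list(it)       # the same iterator is now exhausted
fresh_pass = list(numbers)   # the iterable hands out a new iterator each time
```

The last three lines are the practical consequence: an iterator can be walked once, while an iterable can be walked once per call to `iter()`.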
    

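    1. DataLoader source code walkthrough

    Before analyzing the source, here are the parameters of DataLoader(object):

    • dataset (Dataset): the dataset to load from
    • batch_size (int, optional): number of samples per batch (default: 1)
    • shuffle (bool, optional): reshuffle the data at the start of every epoch (default: False)
    • sampler (Sampler, optional): custom strategy for drawing samples from the dataset; if specified, shuffle must be False
    • batch_sampler (Sampler, optional): like sampler, but returns the indices of one whole batch at a time. Note that once this is specified, batch_size, shuffle, sampler, and drop_last can no longer be specified (they are mutually exclusive)
    • num_workers (int, optional): number of subprocesses used for data loading. 0 means the data is loaded in the main process (default: 0)
    • collate_fn (callable, optional): function that merges a list of samples into a mini-batch
    • pin_memory (bool, optional): if True, the data loader copies tensors into CUDA pinned memory before returning them
    • drop_last (bool, optional): concerns the last, incomplete batch. With batch_size set to 64 and 100 samples per epoch, True drops the trailing 36 samples during training; False (default) keeps them, and the last batch is simply smaller
    • timeout (numeric, optional): if positive, the time to wait for a batch to be collected from the worker processes; in the source below, a RuntimeError is raised if no batch arrives within that time. The value must always be greater than or equal to 0 (default: 0)
    • worker_init_fn (callable, optional): per-worker initialization function. If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading (default: None)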
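Of the parameters above, `collate_fn` is probably the least self-explanatory, so a rough sketch may help. `simple_collate` is a hypothetical stand-in and far simpler than PyTorch's `default_collate`, which stacks tensors; this version only regroups a list of (img, label) samples into one list per field, which is the core idea of "merging samples into a mini-batch".

```python
def simple_collate(samples):
    """Regroup [(img0, label0), (img1, label1), ...] into ([img0, img1, ...], [label0, label1, ...])."""
    imgs, labels = zip(*samples)   # transpose: one tuple per field
    return list(imgs), list(labels)


# Three fake samples, each an (img, label) pair.
samples = [([0.0, 0.1], 0), ([1.0, 1.1], 1), ([2.0, 2.1], 0)]
batch_imgs, batch_labels = simple_collate(samples)
```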

    As the parameter descriptions above suggest, the DataLoader class itself mainly sets up how the data will be loaded. Now for the source:

    class DataLoader(object):
    
        __initialized = False
    
        def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                     batch_sampler=None, num_workers=0, collate_fn=default_collate,
                     pin_memory=False, drop_last=False, timeout=0,
                     worker_init_fn=None):
            self.dataset = dataset
            self.batch_size = batch_size
            self.num_workers = num_workers
            self.collate_fn = collate_fn
            self.pin_memory = pin_memory
            self.drop_last = drop_last
            self.timeout = timeout
            self.worker_init_fn = worker_init_fn
    
            if timeout < 0:
                raise ValueError('timeout option should be non-negative')
    
            if batch_sampler is not None:
                if batch_size > 1 or shuffle or sampler is not None or drop_last:
                    raise ValueError('batch_sampler option is mutually exclusive '
                                     'with batch_size, shuffle, sampler, and '
                                     'drop_last')
                self.batch_size = None
                self.drop_last = None
    
            if sampler is not None and shuffle:
                raise ValueError('sampler option is mutually exclusive with '
                                 'shuffle')
    
            if self.num_workers < 0:
                raise ValueError('num_workers option cannot be negative; '
                                 'use num_workers=0 to disable multiprocessing.')
    
            if batch_sampler is None:
                if sampler is None:
                    if shuffle:
                        sampler = RandomSampler(dataset)
                    else:
                        sampler = SequentialSampler(dataset)
                batch_sampler = BatchSampler(sampler, batch_size, drop_last)
    
            self.sampler = sampler
            self.batch_sampler = batch_sampler
            self.__initialized = True
    
        def __setattr__(self, attr, val):
            if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
                raise ValueError('{} attribute should not be set after {} is '
                                 'initialized'.format(attr, self.__class__.__name__))
    
            super(DataLoader, self).__setattr__(attr, val)
    
        def __iter__(self):
            return _DataLoaderIter(self)
    
        def __len__(self):
            return len(self.batch_sampler)
    

    The parts to focus on here are __init__() and __iter__():

    ① Shuffling and batching of the data

    • RandomSampler(dataset)
    • SequentialSampler(dataset)
    • BatchSampler(sampler, batch_size, drop_last)

    ② Because DataLoader implements only __iter__() and not __next__(), DataLoader is an iterable, not an iterator. The iterator itself is implemented in _DataLoaderIter.
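The same split can be reproduced in a few lines. `TinyLoader` and `_TinyLoaderIter` are hypothetical names that only mirror the DataLoader / _DataLoaderIter division of labor; there is no batching or multiprocessing here. The point is that each call to `__iter__` returns a new iterator, which is why the same loader object can be looped over once per epoch.

```python
class _TinyLoaderIter:
    """Iterator: holds the position of one pass over the data."""

    def __init__(self, loader):
        self.data = loader.data
        self.pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.pos >= len(self.data):
            raise StopIteration
        item = self.data[self.pos]
        self.pos += 1
        return item


class TinyLoader:
    """Iterable: defines no __next__, it only hands out a new iterator per pass."""

    def __init__(self, data):
        self.data = data

    def __iter__(self):
        return _TinyLoaderIter(self)


loader = TinyLoader(["a", "b", "c"])
epochs = [list(loader) for _ in range(2)]   # two full passes over the same loader
```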

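    1.1 DataLoader: RandomSampler(dataset) and SequentialSampler(dataset)

    These are implemented in torch/utils/data/sampler.py, in the same directory as dataloader.py. sampler.py defines a parent class, Sampler, and five subclasses that inherit from it, including SequentialSampler, RandomSampler, and BatchSampler. Every sampler must provide an __iter__ method, which defines how the data is traversed, and a __len__ method, which returns the length of the data.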

    class Sampler(object):
        r"""Base class for all Samplers.
        Every Sampler subclass has to provide an __iter__ method, providing a way
        to iterate over indices of dataset elements, and a __len__ method that
        returns the length of the returned iterators.
        """
    
        def __init__(self, data_source):
            pass
    
        def __iter__(self):
            raise NotImplementedError
    
        def __len__(self):
            raise NotImplementedError
    
    
    class SequentialSampler(Sampler):
        r"""Samples elements sequentially, always in the same order.
        Arguments:
            data_source (Dataset): dataset to sample from
        """
    
        def __init__(self, data_source):
            self.data_source = data_source
    
        def __iter__(self):
            return iter(range(len(self.data_source)))
    
        def __len__(self):
            return len(self.data_source)
    
    
    class RandomSampler(Sampler):
        r"""Samples elements randomly, without replacement.
        Arguments:
            data_source (Dataset): dataset to sample from
        """
    
        def __init__(self, data_source):
            self.data_source = data_source
    
        def __iter__(self):
            return iter(torch.randperm(len(self.data_source)).tolist())
    
        def __len__(self):
            return len(self.data_source)
    
    if __name__ == "__main__":
        print(list(RandomSampler(range(10))))
        # e.g. [2, 8, 3, 5, 9, 4, 6, 0, 1, 7] (order differs from run to run)
        print(list(SequentialSampler(range(10))))
        # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    

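    As shown, RandomSampler and the other samplers return index positions (indices) into the Dataset. The __iter__ method of each subclass must return something of the form iter(xxx), that is, an iterator:

    #### The following two snippets are equivalent
    for data in dataloader:
        ...
    #### is equivalent to
    iters = iter(dataloader)
    while True:
        try:
            data = next(iters)
        except StopIteration:
            break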
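Since a sampler yields indices rather than data, it can be illustrated without any dataset class at all. The two helper functions below are hypothetical and use the standard `random` module in place of `torch.randperm`; the optional `seed` argument is an assumption added only to make the sketch reproducible.

```python
import random


def sequential_indices(n):
    """Indices 0..n-1 in order, like SequentialSampler."""
    return iter(range(n))


def shuffled_indices(n, seed=None):
    """A random permutation of 0..n-1 (sampling without replacement), like RandomSampler."""
    order = list(range(n))
    random.Random(seed).shuffle(order)
    return iter(order)


data = ["a", "b", "c", "d", "e"]
# The indices are only turned into samples at lookup time.
in_order = [data[i] for i in sequential_indices(len(data))]
shuffled = [data[i] for i in shuffled_indices(len(data), seed=0)]
```

Every element appears exactly once in `shuffled`; only the order changes.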
    

    1.2 DataLoader: BatchSampler(Sampler)

    BatchSampler wraps a sampler and yields the indices of mini-batches.

    The main thing to look at here is the __iter__ method; the code shows plainly how the batch indices are drawn.

    class BatchSampler(Sampler):
        r"""Wraps another sampler to yield a mini-batch of indices.
        Args:
            sampler (Sampler): Base sampler.
            batch_size (int): Size of mini-batch.
            drop_last (bool): If ``True``, the sampler will drop the last batch if
                its size would be less than ``batch_size``
        Example:
            >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
            [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
            >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
            [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
        """
    
        def __init__(self, sampler, batch_size, drop_last):
            if not isinstance(sampler, Sampler):
                raise ValueError("sampler should be an instance of "
                                 "torch.utils.data.Sampler, but got sampler={}"
                                 .format(sampler))
            if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
                    batch_size <= 0:
                raise ValueError("batch_size should be a positive integeral value, "
                                 "but got batch_size={}".format(batch_size))
            if not isinstance(drop_last, bool):
                raise ValueError("drop_last should be a boolean value, but got "
                                 "drop_last={}".format(drop_last))
            self.sampler = sampler
            self.batch_size = batch_size
            self.drop_last = drop_last
    
        def __iter__(self):
            batch = []
            # once the batch reaches batch_size it is full and can be yielded
            for idx in self.sampler:
                batch.append(idx)
                if len(batch) == self.batch_size:
                    yield batch
                    batch = []
            if len(batch) > 0 and not self.drop_last:
                yield batch
    
        def __len__(self):
            # e.g. with 100 samples per epoch and batch_size 64, the result is 1 with drop_last and 2 without
            if self.drop_last:
                return len(self.sampler) // self.batch_size
            else:
                return (len(self.sampler) + self.batch_size - 1) // self.batch_size

    if __name__ == "__main__":
        print(list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)))
        # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        print(list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)))
        # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
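The 100-sample, batch_size 64 case mentioned in the `__len__` comment can be worked through with a torch-free sketch. `batch_indices` and `num_batches` are hypothetical helpers that follow the same grouping and counting logic as the listing above, with a plain `range` standing in for the wrapped sampler.

```python
def batch_indices(num_samples, batch_size, drop_last):
    """Group indices 0..num_samples-1 into lists of batch_size."""
    batch = []
    for idx in range(num_samples):
        batch.append(idx)
        if len(batch) == batch_size:
            yield batch
            batch = []
    # A non-empty remainder is an incomplete last batch.
    if batch and not drop_last:
        yield batch


def num_batches(num_samples, batch_size, drop_last):
    """Floor division when dropping the remainder, ceiling division otherwise."""
    if drop_last:
        return num_samples // batch_size
    return (num_samples + batch_size - 1) // batch_size


kept = list(batch_indices(100, 64, drop_last=False))     # one full batch plus a batch of 36
dropped = list(batch_indices(100, 64, drop_last=True))   # the 36 leftover samples are skipped
```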
    

    1.3 _DataLoaderIter
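As orientation for the long listing in this section: with several workers, batches can come back in a different order from the one in which their indices were sent out, and `__next__` uses `rcvd_idx` and `reorder_dict` to hand them to the caller in order. The generator below is a simplified sketch of that idea only. `in_order` and the `arrivals` list are hypothetical; the real method returns one batch per call and blocks on a queue instead of iterating over a list.

```python
def in_order(arrivals):
    """Yield batches in index order from (idx, batch) pairs that may arrive out of order."""
    reorder = {}      # idx -> batch, for batches that arrived early
    expected = 0      # plays the role of rcvd_idx
    for idx, batch in arrivals:
        reorder[idx] = batch
        # Emit everything that is now contiguous with what was already emitted.
        while expected in reorder:
            yield reorder.pop(expected)
            expected += 1


# Batch 1 finishes before batch 0, and batch 3 before batch 2.
arrivals = [(1, "b1"), (0, "b0"), (3, "b3"), (2, "b2")]
ordered = list(in_order(arrivals))
```

Buffering early arrivals keeps the output order independent of which worker happens to finish first.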

    _DataLoaderIter is simply what DataLoader's __iter__() method returns:

    class _DataLoaderIter(object):
        r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
    
        # NOTE [ Data Loader Multiprocessing Shutdown Logic ]
        #
        # Preliminary:
        #
        # Our data model looks like this (queues are indicated with curly brackets):
        #
        #                main process                              ||
        #                     |                                    ||
        #               {index_queue}                              ||
        #                     |                                    ||
        #              worker processes                            ||     DATA
        #                     |                                    ||
        #            {worker_result_queue}                         ||     FLOW
        #                     |                                    ||
        #      pin_memory_thread of main process                   ||   DIRECTION
        #                     |                                    ||
        #               {data_queue}                               ||
        #                     |                                    ||
        #                data output                               \/
        #
        # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
        #      `pin_memory=False`.
        #
        #
        # Terminating multiprocessing logic requires very careful design. In
        # particular, we need to make sure that
        #
        #   1. The iterator gracefully exits the workers when its last reference is
        #      gone or it is depleted.
        #
        #      In this case, the workers should be gracefully exited because the
        #      main process may still need to continue to run, and we want cleaning
        #      up code in the workers to be executed (e.g., releasing GPU memory).
        #      Naturally, we implement the shutdown logic in `__del__` of
        #      DataLoaderIterator.
        #
        #      We delay the discussion on the logic in this case until later.
        #
        #   2. The iterator exits the workers when the loader process and/or worker
        #      processes exits normally or with error.
        #
        #      We set all workers and `pin_memory_thread` to have `daemon=True`.
        #
        #      You may ask, why can't we make the workers non-daemonic, and
        #      gracefully exit using the same logic as we have in `__del__` when the
        #      iterator gets deleted (see 1 above)?
        #
        #      First of all, `__del__` is **not** guaranteed to be called when
        #      interpreter exits. Even if it is called, by the time it executes,
        #      many Python core library resources may already be freed, and even
        #      simple things like acquiring an internal lock of a queue may hang.
        #      Therefore, in this case, we actually need to prevent `__del__` from
        #      being executed, and rely on the automatic termination of daemonic
        #      children. Thus, we register an `atexit` hook that sets a global flag
        #      `_utils.python_exit_status`. Since `atexit` hooks are executed in the
        #      reverse order of registration, we are guaranteed that this flag is
        #      set before library resources we use are freed. (Hooks freeing those
        #      resources are registered at importing the Python core libraries at
        #      the top of this file.) So in `__del__`, we check if
        #      `_utils.python_exit_status` is set or `None` (freed), and perform
        #      no-op if so.
        #
        #      Another problem with `__del__` is also related to the library cleanup
        #      calls. When a process ends, it shuts the all its daemonic children
        #      down with a SIGTERM (instead of joining them without a timeout).
        #      Similarly for threads, but by a different mechanism. This fact,
        #      together with a few implementation details of multiprocessing, forces
        #      us to make workers daemonic. All of our problems arise when a
        #      DataLoader is used in a subprocess, and are caused by multiprocessing
        #      code which looks more or less like this:
        #
        #          try:
        #              your_function_using_a_dataloader()
        #          finally:
        #              multiprocessing.util._exit_function()
        #
        #      The joining/termination mentioned above happens inside
        #      `_exit_function()`. Now, if `your_function_using_a_dataloader()`
        #      throws, the stack trace stored in the exception will prevent the
        #      frame which uses `DataLoaderIter` to be freed. If the frame has any
        #      reference to the `DataLoaderIter` (e.g., in a method of the iter),
        #      its  `__del__`, which starts the shutdown procedure, will not be
        #      called. That, in turn, means that workers aren't notified. Attempting
        #      to join in `_exit_function` will then result in a hang.
        #
        #      For context, `_exit_function` is also registered as an `atexit` call.
        #      So it is unclear to me (@ssnl) why this is needed in a finally block.
        #      The code dates back to 2008 and there is no comment on the original
        #      PEP 371 or patch https://bugs.python.org/issue3050 (containing both
        #      the finally block and the `atexit` registration) that explains this.
        #
        #      Another choice is to just shutdown workers with logic in 1 above
        #      whenever we see an error in `next`. This isn't ideal because
        #        a. It prevents users from using try-catch to resume data loading.
        #        b. It doesn't prevent hanging if users have references to the
        #           iterator.
        #
        #   3. All processes exit if any of them die unexpectedly by fatal signals.
        #
        #      As shown above, the workers are set as daemonic children of the main
        #      process. However, automatic cleaning-up of such child processes only
        #      happens if the parent process exits gracefully (e.g., not via fatal
        #      signals like SIGKILL). So we must ensure that each process will exit
        #      even the process that should send/receive data to/from it were
        #      killed, i.e.,
        #
        #        a. A process won't hang when getting from a queue.
        #
        #           Even with carefully designed data dependencies (i.e., a `put()`
        #           always corresponding to a `get()`), hanging on `get()` can still
        #           happen when data in queue is corrupted (e.g., due to
        #           `cancel_join_thread` or unexpected exit).
        #
        #           For child exit, we set a timeout whenever we try to get data
        #           from `data_queue`, and check the workers' status on each timeout
        #           and error.
        #           See `_DataLoaderIter._get_batch()` and
        #           `_DataLoaderIter._try_get_batch()` for details.
        #
        #           Additionally, for child exit on non-Windows platforms, we also
        #           register a SIGCHLD handler (which is not supported on Windows) on
        #           the main process, which checks if any of the workers fail in the
        #           (Python) handler. This is more efficient and faster in detecting
        #           worker failures, compared to only using the above mechanism.
        #           See `DataLoader.cpp` and `_utils/signal_handling.py` for details.
        #
        #           For `.get()` calls where the sender(s) is not the workers, we
        #           guard them with timeouts, and check the status of the sender
        #           when timeout happens:
        #             + in the workers, the `_utils.worker.ManagerWatchdog` class
        #               checks the status of the main process.
        #             + if `pin_memory=True`, when getting from `pin_memory_thread`,
        #               check `pin_memory_thread` status periodically until `.get()`
        #               returns or see that `pin_memory_thread` died.
        #
        #        b. A process won't hang when putting into a queue;
        #
        #           We use `mp.Queue` which has a separate background thread to put
        #           objects from an unbounded buffer array. The background thread is
        #           daemonic and usually automatically joined when the process
        #           exits.
        #
        #           However, in case that the receiver has ended abruptly while
        #           reading from the pipe, the join will hang forever. Therefore,
        #           for both `worker_result_queue` (worker -> main process/pin_memory_thread)
        #           and each `index_queue` (main process -> worker), we use
        #           `q.cancel_join_thread()` in sender process before any `q.put` to
        #           prevent this automatic join.
        #
        #           Moreover, having all queues called `cancel_join_thread` makes
        #           implementing graceful shutdown logic in `__del__` much easier.
        #           It won't need to get from any queue, which would also need to be
        #           guarded by periodic status checks.
        #
        #           Note that this may leave corrupted data in the queue, but we
        #           don't care about the data anyways once we are shutting down.
        #
        #
        # Now let's get back to 1:
        #   how we gracefully exit the workers when the last reference to the
        #   iterator is gone.
        #
        # To achieve this, we implement the following logic along with the design
        # choices mentioned above:
        #
        # [worker processes]
        #   While loader process is alive:
        #     Get from index_queue.
        #       If got a `None`, exit.
        #       If get anything else,
        #          Check `done_event`.
        #            If set, continue to next iteration
        #                    i.e., keep getting until see the `None`, then exit.
        #            Otherwise, process data.
        #       If timed out,
        #          No matter `done_event` is set (still need to see `None`) or not,
        #          must continue to next iteration .
        #
        # [pin_memory_thread]
        #   # No need to check main thread. If this thread is alive, the main loader
        #   # thread must be alive, because this thread is set as daemonic.
        #   While True:
        #     Get from index_queue.
        #       If got a `None`, exit.
        #       If get anything else,
        #          Check `done_event`.
        #            If set, continue to next iteration
        #                    i.e., keep getting until see the `None`, then exit.
        #            Otherwise, process data.
        #
        #   NOTE: we don't check the status of the main thread because
        #           1. if the process is killed by fatal signal, `pin_memory_thread`
        #              ends.
        #           2. in other cases, either the cleaning-up in __del__ or the
        #              automatic exit of daemonic thread will take care of it.
        #              This won't busy-wait either because `.get(timeout)` does not
        #              busy-wait.
        #
        # [main process]
        #   In the DataLoader Iter's `__del__`
        #     a. Set `done_event` (shared with `pin_memory_thread` and workers).
        #
        #        Note: from here on, the workers & `pin_memory_thread` may exit at
        #              any time after they receive `None`.
        #
        #     b. Exit `pin_memory_thread`
        #          i.   Put `None` in `worker_result_queue`.
        #          ii.  Join the `pin_memory_thread`.
        #
        #     c. Exit the workers.
        #          i.   Put `None` in each worker's `index_queue`.
        #          ii.  Join the workers.
        #
        #        NOTE: This has to be after (b) because it may leave corrupted data
        #              in `worker_result_queue`, which `pin_memory_thread` reads
        #              from.
        #
        #   NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b)
        #         can be omitted
        #
        # NB: `done_event`s isn't strictly needed. E.g., we can just check for
        #     `None` from `index_queue`, but it allows us to skip wasting resources
        #     processing indices already in `index_queue` if we are already shutting
        #     down.
    
        def __init__(self, loader):
            self.dataset = loader.dataset
            self.collate_fn = loader.collate_fn
            self.batch_sampler = loader.batch_sampler
            self.num_workers = loader.num_workers
            self.pin_memory = loader.pin_memory and torch.cuda.is_available()
            self.timeout = loader.timeout
    
            self.sample_iter = iter(self.batch_sampler)
    
            base_seed = torch.LongTensor(1).random_().item()
    
            if self.num_workers > 0:
                self.worker_init_fn = loader.worker_init_fn
                self.worker_queue_idx = 0
                self.worker_result_queue = multiprocessing.Queue()
                self.batches_outstanding = 0
                self.worker_pids_set = False
                self.shutdown = False
                self.send_idx = 0
                self.rcvd_idx = 0
                self.reorder_dict = {}
                self.done_event = multiprocessing.Event()
    
                self.index_queues = []
                self.workers = []
                for i in range(self.num_workers):
                    index_queue = multiprocessing.Queue()
                    index_queue.cancel_join_thread()
                    w = multiprocessing.Process(
                        target=_utils.worker._worker_loop,
                        args=(self.dataset, index_queue,
                              self.worker_result_queue, self.done_event,
                              self.collate_fn, base_seed + i,
                              self.worker_init_fn, i))
                    w.daemon = True
                    # NB: Process.start() actually take some time as it needs to
                    #     start a process and pass the arguments over via a pipe.
                    #     Therefore, we only add a worker to self.workers list after
                    #     it started, so that we do not call .join() if program dies
                    #     before it starts, and __del__ tries to join but will get:
                    #     AssertionError: can only join a started process.
                    w.start()
                    self.index_queues.append(index_queue)
                    self.workers.append(w)
    
                if self.pin_memory:
                    self.data_queue = queue.Queue()
                    pin_memory_thread = threading.Thread(
                        target=_utils.pin_memory._pin_memory_loop,
                        args=(self.worker_result_queue, self.data_queue,
                              torch.cuda.current_device(), self.done_event))
                    pin_memory_thread.daemon = True
                    pin_memory_thread.start()
                    # Similar to workers (see comment above), we only register
                    # pin_memory_thread once it is started.
                    self.pin_memory_thread = pin_memory_thread
                else:
                    self.data_queue = self.worker_result_queue
    
                _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self.workers))
                _utils.signal_handling._set_SIGCHLD_handler()
                self.worker_pids_set = True
    
                # prime the prefetch loop
                for _ in range(2 * self.num_workers):
                    self._put_indices()
    
        def __len__(self):
            return len(self.batch_sampler)
    
        def _try_get_batch(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
            # Tries to fetch data from `data_queue` for a given timeout. This can
            # also be used as inner loop of fetching without timeout, with the
            # sender status as the loop condition.
            #
            # This raises a `RuntimeError` if any worker died unexpectedly. This error
            # can come from either the SIGCHLD handler in `_utils/signal_handling.py`
            # (only for non-Windows platforms), or the manual check below on errors
            # and timeouts.
            #
            # Returns a 2-tuple:
            #   (bool: whether successfully get data, any: data if successful else None)
            try:
                data = self.data_queue.get(timeout=timeout)
                return (True, data)
            except Exception as e:
                # At timeout and error, we manually check whether any worker has
                # failed. Note that this is the only mechanism for Windows to detect
                # worker failures.
                if not all(w.is_alive() for w in self.workers):
                    pids_str = ', '.join(str(w.pid) for w in self.workers if not w.is_alive())
                    raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str))
                if isinstance(e, queue.Empty):
                    return (False, None)
                raise
    
        def _get_batch(self):
            # Fetches data from `self.data_queue`.
            #
            # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
            # which we achieve by running `self._try_get_batch(timeout=MP_STATUS_CHECK_INTERVAL)`
            # in a loop. This is the only mechanism to detect worker failures for
            # Windows. For other platforms, a SIGCHLD handler is also used for
            # worker failure detection.
            #
            # If `pin_memory=True`, we also need check if `pin_memory_thread` had
            # died at timeouts.
            if self.timeout > 0:
                success, data = self._try_get_batch(self.timeout)
                if success:
                    return data
                else:
                    raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
            elif self.pin_memory:
                while self.pin_memory_thread.is_alive():
                    success, data = self._try_get_batch()
                    if success:
                        return data
                else:
                    # while condition is false, i.e., pin_memory_thread died.
                    raise RuntimeError('Pin memory thread exited unexpectedly')
                # In this case, `self.data_queue` is a `queue.Queue`. But we don't
                # need to call `.task_done()` because we don't use `.join()`.
            else:
                while True:
                    success, data = self._try_get_batch()
                    if success:
                        return data
    
        def __next__(self):
            if self.num_workers == 0:  # same-process loading
                indices = next(self.sample_iter)  # may raise StopIteration
                batch = self.collate_fn([self.dataset[i] for i in indices])
                if self.pin_memory:
                    batch = _utils.pin_memory.pin_memory_batch(batch)
                return batch
    
            # check if the next sample has already been generated
            if self.rcvd_idx in self.reorder_dict:
                batch = self.reorder_dict.pop(self.rcvd_idx)
                return self._process_next_batch(batch)
    
            if self.batches_outstanding == 0:
                self._shutdown_workers()
                raise StopIteration
    
            while True:
                assert (not self.shutdown and self.batches_outstanding > 0)
                idx, batch = self._get_batch()
                self.batches_outstanding -= 1
                if idx != self.rcvd_idx:
                    # store out-of-order samples
                    self.reorder_dict[idx] = batch
                    continue
                return self._process_next_batch(batch)
    
        next = __next__  # Python 2 compatibility
    
        def __iter__(self):
            return self
    
        def _put_indices(self):
            assert self.batches_outstanding < 2 * self.num_workers
            indices = next(self.sample_iter, None)
            if indices is None:
                return
            self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
            self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
            self.batches_outstanding += 1
            self.send_idx += 1
    
        def _process_next_batch(self, batch):
            self.rcvd_idx += 1
            self._put_indices()
            if isinstance(batch, _utils.ExceptionWrapper):
                # make multiline KeyError msg readable by working around
                # a python bug https://bugs.python.org/issue2651
                if batch.exc_type == KeyError and "\n" in batch.exc_msg:
                    raise Exception("KeyError:" + batch.exc_msg)
                else:
                    raise batch.exc_type(batch.exc_msg)
            return batch
    
        def __getstate__(self):
            # TODO: add limited pickling support for sharing an iterator
            # across multiple threads for HOGWILD.
            # Probably the best way to do this is by moving the sample pushing
            # to a separate thread and then just sharing the data queue
            # but signalling the end is tricky without a non-blocking API
            raise NotImplementedError("_DataLoaderIter cannot be pickled")
    
        def _shutdown_workers(self):
            # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
            # the logic of this function.
            python_exit_status = _utils.python_exit_status
            if python_exit_status is True or python_exit_status is None:
                # See (2) of the note. If Python is shutting down, do no-op.
                return
            # Normal exit when last reference is gone / iterator is depleted.
            # See (1) and the second half of the note.
            if not self.shutdown:
                self.shutdown = True
                try:
                    self.done_event.set()
    
                    # Exit `pin_memory_thread` first because exiting workers may leave
                    # corrupted data in `worker_result_queue` which `pin_memory_thread`
                    # reads from.
                    if hasattr(self, 'pin_memory_thread'):
                        # Use hasattr in case error happens before we set the attribute.
                        # First time do `worker_result_queue.put` in this process.
    
                        # `cancel_join_thread` in case that `pin_memory_thread` exited.
                        self.worker_result_queue.cancel_join_thread()
                        self.worker_result_queue.put(None)
                        self.pin_memory_thread.join()
                        # Indicate that no more data will be put on this queue by the
                        # current process. This **must** be called after
                        # `pin_memory_thread` is joined because that thread shares the
                        # same pipe handles with this loader thread. If the handle is
                        # closed, Py3 will error in this case, but Py2 will just time
                        # out even if there is data in the queue.
                        self.worker_result_queue.close()
    
                    # Exit workers now.
                    for q in self.index_queues:
                        q.put(None)
                        # Indicate that no more data will be put on this queue by the
                        # current process.
                        q.close()
                    for w in self.workers:
                        w.join()
                finally:
                    # Even though all this function does is putting into queues that
                    # we have called `cancel_join_thread` on, weird things can
                    # happen when a worker is killed by a signal, e.g., hanging in
                    # `Event.set()`. So we need to guard this with SIGCHLD handler,
                    # and remove pids from the C side data structure only at the
                    # end.
                    #
                    # FIXME: Unfortunately, for Windows, we are missing a worker
                    #        error detection mechanism here in this function, as it
                    #        doesn't provide a SIGCHLD handler.
                    if self.worker_pids_set:
                        _utils.signal_handling._remove_worker_pids(id(self))
                        self.worker_pids_set = False
    
        def __del__(self):
            if self.num_workers > 0:
                self._shutdown_workers()
    

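    In self.index_queue = multiprocessing.SimpleQueue(), multiprocessing is Python's package for managing multiple processes, while threading is its package for managing multiple threads. A large part of their interfaces is used in the same way.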
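To make the similarity between the two interfaces concrete, here is a minimal sketch that uses only the standard library. The names `square_worker` and `run_with` are hypothetical and are not part of PyTorch; the `None` sentinel imitates the way the listing above tells its workers to stop.

```python
import queue
import threading


def square_worker(in_queue, out_queue):
    """Read numbers until the None sentinel arrives and emit their squares."""
    while True:
        item = in_queue.get()
        if item is None:  # sentinel: no more work
            break
        out_queue.put(item * item)


def run_with(worker_cls, queue_cls, values):
    """Run square_worker in a single worker built from worker_cls.

    worker_cls may be threading.Thread or multiprocessing.Process; both
    accept target= and args= and expose start() and join().
    """
    in_queue, out_queue = queue_cls(), queue_cls()
    worker = worker_cls(target=square_worker, args=(in_queue, out_queue))
    worker.daemon = True
    worker.start()
    for value in values:
        in_queue.put(value)
    in_queue.put(None)  # ask the worker to exit
    results = [out_queue.get() for _ in values]
    worker.join()
    return results


print(run_with(threading.Thread, queue.Queue, [1, 2, 3]))  # [1, 4, 9]
```

Passing `multiprocessing.Process` and `multiprocessing.SimpleQueue` instead runs the same function in a child process. On platforms that spawn rather than fork, that call has to sit under an `if __name__ == "__main__":` guard.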

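    In the __init__ function, the first part consists mostly of assignments. The one worth noting is self.sample_iter = iter(self.batch_sampler): calling next(self.sample_iter) on the result returns the indices of one batch, that is, batch size sample indices. self.rcvd_idx is the index of the batch being read; it is initialized to 0 and is used while iterating over the data.

    The if self.num_workers statement performs the initialization that is specific to multi-process loading. If multi-process loading is not enabled, none of this initialization is needed; single-process loading is covered separately. Inside the if statement, multiprocessing.SimpleQueue() creates a simple queue object. multiprocessing.Process is the class that constructs a process; here one is created for each configured worker and the results are assigned to self.workers. The for loop that follows calls start on each process in self.workers in turn.

    Next comes the check on self.pin_memory, whose body mainly sets up a thread. The meaning of self.pin_memory was explained earlier: when it is True, the data is copied into CUDA pinned memory (page-locked host memory). self.data_queue = queue.Queue() uses Python's queue module to create a first-in, first-out queue (the module can also create a last-in, first-out queue with queue.LifoQueue()). The queue module is mainly used when data is read by multiple threads. In the args of threading.Thread, the first argument, in_data, is the data coming from a process; data shared between threads of the same process is also managed through a queue, here one created with queue.Queue(). Once initialization is finished, the __next__ method is called.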
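The two ideas in the paragraph above, pulling one batch of indices at a time with `next(...)` and the difference between `queue.Queue` and `queue.LifoQueue`, can be sketched without PyTorch. The generator `batch_indices` is a hypothetical stand-in for a batch sampler, written only for illustration; it is not the library's implementation.

```python
import queue


def batch_indices(num_samples, batch_size, drop_last=False):
    """Hypothetical stand-in for a batch sampler: yields lists of indices."""
    batch = []
    for idx in range(num_samples):
        batch.append(idx)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch and not drop_last:
        yield batch  # final, smaller batch


# Same pattern as self.sample_iter = iter(self.batch_sampler).
sample_iter = iter(batch_indices(num_samples=10, batch_size=4))
first = next(sample_iter)            # [0, 1, 2, 3]
rest = list(sample_iter)             # [[4, 5, 6, 7], [8, 9]]
exhausted = next(sample_iter, None)  # None once the sampler is used up

# FIFO versus LIFO queues from the queue module.
fifo, lifo = queue.Queue(), queue.LifoQueue()
for item in ("a", "b", "c"):
    fifo.put(item)
    lifo.put(item)
fifo_order = [fifo.get() for _ in range(3)]  # ['a', 'b', 'c']
lifo_order = [lifo.get() for _ in range(3)]  # ['c', 'b', 'a']
```

The `next(sample_iter, None)` form is the one `_put_indices` uses in the listing: a `None` result means there are no more batches to hand out.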

    In short, when multi-process loading is enabled the data is read through queues; otherwise it is read in the ordinary way, directly in the main process.
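One consequence of reading through queues is that batches can arrive in a different order from the one in which their indices were sent out, which is why the listing above keeps `reorder_dict` and `rcvd_idx`. The sketch below is a simplified, single-threaded imitation of that bookkeeping; the function name `ordered_batches` and the sample data are made up for illustration.

```python
def ordered_batches(arrivals):
    """Yield batches in index order from (idx, batch) pairs in any order."""
    reorder_dict = {}  # batches that arrived early, keyed by their index
    rcvd_idx = 0       # index of the next batch that may be returned
    for idx, batch in arrivals:
        reorder_dict[idx] = batch
        # Release every batch that is now next in line.
        while rcvd_idx in reorder_dict:
            yield reorder_dict.pop(rcvd_idx)
            rcvd_idx += 1


# Batches 1 and 3 arrive before batches 0 and 2.
arrivals = [(1, "batch1"), (0, "batch0"), (3, "batch3"), (2, "batch2")]
in_order = list(ordered_batches(arrivals))
```

In the single-process case there is nothing to reorder, because each batch is produced on demand in the order the sampler yields its indices.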

      

    PyTorch学习笔记(6)——DataLoader源代码剖析

    pytorch学习笔记(十四): DataLoader源码阅读

  • Original article: https://www.cnblogs.com/carsonzhu/p/11309161.html