zoukankan      html  css  js  c++  java
  • pytorch之dataloader深入剖析

    - dataloader本质是一个可迭代对象,使用iter()访问,不能使用next()访问;

    - 使用iter(dataloader)返回的是一个迭代器,然后可以使用next访问;

    - 也可以使用`for inputs, labels in dataloaders`进行可迭代对象的访问;

    - 一般我们实现一个datasets对象,传入到dataloader中;然后内部使用yeild返回每一次batch的数据;

    ① DataLoader本质上就是一个iterable(跟python的内置类型list等一样),并利用多进程来加速batch data的处理,使用yield来使用有限的内存
    ​
    ② Queue的特点
    
    当队列里面没有数据时: queue.get() 会阻塞, 阻塞的时候,其它进程/线程如果有queue.put() 操作,本线程/进程会被通知,然后就可以 get 成功。
    当数据满了: queue.put() 会阻塞
    
    ③ DataLoader是一个高效,简洁,直观的网络输入数据结构,便于使用和扩展

    输入数据PipeLine
    pytorch 的数据加载到模型的操作顺序是这样的:

    ① 创建一个 Dataset 对象
    ② 创建一个 DataLoader 对象
    ③ 循环这个 DataLoader 对象,将img, label加载到模型中进行训练

    dataset = MyDataset()
    dataloader = DataLoader(dataset)
    num_epoches = 100
    for epoch in range(num_epoches):
        for img, label in dataloader:
            ....
    所以,作为直接对数据进入模型中的关键一步, DataLoader非常重要。

    首先简单介绍一下DataLoader,它是PyTorch中数据读取的一个重要接口,该接口定义在dataloader.py中,只要是用PyTorch来训练模型基本都会用到该接口(除非用户重写…),该接口的目的:将自定义的Dataset根据batch size大小、是否shuffle等封装成一个Batch Size大小的Tensor,用于后面的训练。

    官方对DataLoader的说明是:“数据加载由数据集和采样器组成,基于python的单、多进程的iterators来处理数据。”关于iterator和iterable的区别和概念请自行查阅,在实现中的差别就是iterators有__iter__和__next__方法,而iterable只有__iter__方法。

    1.DataLoader

    先介绍一下DataLoader(object)的参数:

        dataset(Dataset): 传入的数据集
        batch_size(int, optional): 每个batch有多少个样本
        shuffle(bool, optional): 在每个epoch开始的时候,对数据进行重新排序
        sampler(Sampler, optional): 自定义从数据集中取样本的策略,如果指定这个参数,那么shuffle必须为False
        batch_sampler(Sampler, optional): 与sampler类似,但是一次只返回一个batch的indices(索引),需要注意的是,一旦指定了这个参数,那么batch_size,shuffle,sampler,drop_last就不能再制定了(互斥——Mutually exclusive)
        num_workers (int, optional): 这个参数决定了有几个进程来处理data loading。0意味着所有的数据都会被load进主进程。(默认为0)
        collate_fn (callable, optional): 将一个list的sample组成一个mini-batch的函数
        pin_memory (bool, optional): 如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中.
    
        drop_last (bool, optional): 如果设置为True:这个是对最后的未完成的batch来说的,比如你的batch_size设置为64,而一个epoch只有100个样本,那么训练的时候后面的36个就被扔掉了…
        如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点。
    
        timeout(numeric, optional): 如果是正数,表明等待从worker进程中收集一个batch等待的时间,若超出设定的时间还没有收集到,那就不收集这个内容了。这个numeric应总是大于等于0。默认为0
        worker_init_fn (callable, optional): 每个worker初始化函数 If not None, this will be called on each
        worker subprocess with the worker id (an int in [0, num_workers - 1]) as
        input, after seeding and before data loading. (default: None) 

    - 首先dataloader初始化时得到datasets的采样list

    class DataLoader(object):
        r"""
        Data loader. Combines a dataset and a sampler, and provides
        single- or multi-process iterators over the dataset.
    
        Arguments:
            dataset (Dataset): dataset from which to load the data.
            batch_size (int, optional): how many samples per batch to load
                (default: 1).
            shuffle (bool, optional): set to ``True`` to have the data reshuffled
                at every epoch (default: False).
            sampler (Sampler, optional): defines the strategy to draw samples from
                the dataset. If specified, ``shuffle`` must be False.
            batch_sampler (Sampler, optional): like sampler, but returns a batch of
                indices at a time. Mutually exclusive with batch_size, shuffle,
                sampler, and drop_last.
            num_workers (int, optional): how many subprocesses to use for data
                loading. 0 means that the data will be loaded in the main process.
                (default: 0)
            collate_fn (callable, optional): merges a list of samples to form a mini-batch.
            pin_memory (bool, optional): If ``True``, the data loader will copy tensors
                into CUDA pinned memory before returning them.
            drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
                if the dataset size is not divisible by the batch size. If ``False`` and
                the size of dataset is not divisible by the batch size, then the last batch
                will be smaller. (default: False)
            timeout (numeric, optional): if positive, the timeout value for collecting a batch
                from workers. Should always be non-negative. (default: 0)
            worker_init_fn (callable, optional): If not None, this will be called on each
                worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
                input, after seeding and before data loading. (default: None)
    
        .. note:: By default, each worker will have its PyTorch seed set to
                  ``base_seed + worker_id``, where ``base_seed`` is a long generated
                  by main process using its RNG. However, seeds for other libraies
                  may be duplicated upon initializing workers (w.g., NumPy), causing
                  each worker to return identical random numbers. (See
                  :ref:`dataloader-workers-random-seed` section in FAQ.) You may
                  use ``torch.initial_seed()`` to access the PyTorch seed for each
                  worker in :attr:`worker_init_fn`, and use it to set other seeds
                  before data loading.
    
        .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
                     unpicklable object, e.g., a lambda function.
        """
    
        __initialized = False
    
        def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                     num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
                     timeout=0, worker_init_fn=None):
            self.dataset = dataset
            self.batch_size = batch_size
            self.num_workers = num_workers
            self.collate_fn = collate_fn
            self.pin_memory = pin_memory
            self.drop_last = drop_last
            self.timeout = timeout
            self.worker_init_fn = worker_init_fn
    
            if timeout < 0:
                raise ValueError('timeout option should be non-negative')
    
            if batch_sampler is not None:
                if batch_size > 1 or shuffle or sampler is not None or drop_last:
                    raise ValueError('batch_sampler option is mutually exclusive '
                                     'with batch_size, shuffle, sampler, and '
                                     'drop_last')
                self.batch_size = None
                self.drop_last = None
    
            if sampler is not None and shuffle:
                raise ValueError('sampler option is mutually exclusive with '
                                 'shuffle')
    
            if self.num_workers < 0:
                raise ValueError('num_workers option cannot be negative; '
                                 'use num_workers=0 to disable multiprocessing.')
    
            if batch_sampler is None:
                if sampler is None:
                    if shuffle:
                        sampler = RandomSampler(dataset)  //将list打乱
                    else:
                        sampler = SequentialSampler(dataset)
                batch_sampler = BatchSampler(sampler, batch_size, drop_last)
    
            self.sampler = sampler
            self.batch_sampler = batch_sampler
            self.__initialized = True
    
        def __setattr__(self, attr, val):
            if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
                raise ValueError('{} attribute should not be set after {} is '
                                 'initialized'.format(attr, self.__class__.__name__))
    
            super(DataLoader, self).__setattr__(attr, val)
    
        def __iter__(self):
            return _DataLoaderIter(self)
    
        def __len__(self):
            return len(self.batch_sampler)

    其中:RandomSampler,BatchSampler已经得到了采用batch数据的index索引;yield batch机制已经在!!!

    class RandomSampler(Sampler):
        r"""Samples elements randomly, without replacement.
    
        Arguments:
            data_source (Dataset): dataset to sample from
        """
    
        def __init__(self, data_source):
            self.data_source = data_source
    
        def __iter__(self):
            return iter(torch.randperm(len(self.data_source)).tolist())
    
        def __len__(self):
            return len(self.data_source)
    class BatchSampler(Sampler):
        r"""Wraps another sampler to yield a mini-batch of indices.
    
        Args:
            sampler (Sampler): Base sampler.
            batch_size (int): Size of mini-batch.
            drop_last (bool): If ``True``, the sampler will drop the last batch if
                its size would be less than ``batch_size``
    
        Example:
            >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
            [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
            >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
            [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
        """
    
        def __init__(self, sampler, batch_size, drop_last):
            if not isinstance(sampler, Sampler):
                raise ValueError("sampler should be an instance of "
                                 "torch.utils.data.Sampler, but got sampler={}"
                                 .format(sampler))
            if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or 
                    batch_size <= 0:
                raise ValueError("batch_size should be a positive integeral value, "
                                 "but got batch_size={}".format(batch_size))
            if not isinstance(drop_last, bool):
                raise ValueError("drop_last should be a boolean value, but got "
                                 "drop_last={}".format(drop_last))
            self.sampler = sampler
            self.batch_size = batch_size
            self.drop_last = drop_last
    
        def __iter__(self):
            batch = []
            for idx in self.sampler:
                batch.append(idx)
                if len(batch) == self.batch_size:
                    yield batch
                    batch = []
            if len(batch) > 0 and not self.drop_last:
                yield batch
    
        def __len__(self):
            if self.drop_last:
                return len(self.sampler) // self.batch_size
            else:
                return (len(self.sampler) + self.batch_size - 1) // self.batch_size

    - 其中 _DataLoaderIter(self)输入为一个dataloader对象;如果num_workers=0很好理解,num_workers!=0引入多线程机制,加速数据加载过程;

    - 没有多线程时:batch = self.collate_fn([self.dataset[i] for i in indices])进行将index转化为data数据,返回(image,label);self.dataset[i]会调用datasets对象的

    __getitem__()方法;

    - 多线程下,会为每个线程创建一个索引队列index_queues;共享一个worker_result_queue数据队列!在_worker_loop方法中加载数据;

    class _DataLoaderIter(object):
        r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
    
        def __init__(self, loader):
            self.dataset = loader.dataset
            self.collate_fn = loader.collate_fn
            self.batch_sampler = loader.batch_sampler
            self.num_workers = loader.num_workers
            self.pin_memory = loader.pin_memory and torch.cuda.is_available()
            self.timeout = loader.timeout
            self.done_event = threading.Event()
    
            self.sample_iter = iter(self.batch_sampler)
    
            base_seed = torch.LongTensor(1).random_().item()
    
            if self.num_workers > 0:
                self.worker_init_fn = loader.worker_init_fn
                self.index_queues = [multiprocessing.Queue() for _ in range(self.num_workers)]
                self.worker_queue_idx = 0
                self.worker_result_queue = multiprocessing.SimpleQueue()
                self.batches_outstanding = 0
                self.worker_pids_set = False
                self.shutdown = False
                self.send_idx = 0
                self.rcvd_idx = 0
                self.reorder_dict = {}
    
                self.workers = [
                    multiprocessing.Process(
                        target=_worker_loop,
                        args=(self.dataset, self.index_queues[i],
                              self.worker_result_queue, self.collate_fn, base_seed + i,
                              self.worker_init_fn, i))
                    for i in range(self.num_workers)]
    
                if self.pin_memory or self.timeout > 0:
                    self.data_queue = queue.Queue()
                    if self.pin_memory:
                        maybe_device_id = torch.cuda.current_device()
                    else:
                        # do not initialize cuda context if not necessary
                        maybe_device_id = None
                    self.worker_manager_thread = threading.Thread(
                        target=_worker_manager_loop,
                        args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
                              maybe_device_id))
                    self.worker_manager_thread.daemon = True
                    self.worker_manager_thread.start()
                else:
                    self.data_queue = self.worker_result_queue
    
                for w in self.workers:
                    w.daemon = True  # ensure that the worker exits on process exit
                    w.start()
    
                _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
                _set_SIGCHLD_handler()
                self.worker_pids_set = True
    
                # prime the prefetch loop
                for _ in range(2 * self.num_workers):
                    self._put_indices()
    
        def __len__(self):
            return len(self.batch_sampler)
    
        def _get_batch(self):
            if self.timeout > 0:
                try:
                    return self.data_queue.get(timeout=self.timeout)
                except queue.Empty:
                    raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
            else:
                return self.data_queue.get()
    
        def __next__(self):
            if self.num_workers == 0:  # same-process loading
                indices = next(self.sample_iter)  # may raise StopIteration
                batch = self.collate_fn([self.dataset[i] for i in indices])
                if self.pin_memory:
                    batch = pin_memory_batch(batch)
                return batch
    
            # check if the next sample has already been generated
            if self.rcvd_idx in self.reorder_dict:
                batch = self.reorder_dict.pop(self.rcvd_idx)
                return self._process_next_batch(batch)
    
            if self.batches_outstanding == 0:
                self._shutdown_workers()
                raise StopIteration
    
            while True:
                assert (not self.shutdown and self.batches_outstanding > 0)
                idx, batch = self._get_batch()
                self.batches_outstanding -= 1
                if idx != self.rcvd_idx:
                    # store out-of-order samples
                    self.reorder_dict[idx] = batch
                    continue
                return self._process_next_batch(batch)
    
        next = __next__  # Python 2 compatibility
    
        def __iter__(self):
            return self
    
        def _put_indices(self):
            assert self.batches_outstanding < 2 * self.num_workers
            indices = next(self.sample_iter, None)
            if indices is None:
                return
            self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
            self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
            self.batches_outstanding += 1
            self.send_idx += 1
    
        def _process_next_batch(self, batch):
            self.rcvd_idx += 1
            self._put_indices()
            if isinstance(batch, ExceptionWrapper):
                raise batch.exc_type(batch.exc_msg)
            return batch
    def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
        global _use_shared_memory
        _use_shared_memory = True
    
        # Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
        # module's handlers are executed after Python returns from C low-level
        # handlers, likely when the same fatal signal happened again already.
        # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
        _set_worker_signal_handlers()
    
        torch.set_num_threads(1)
        random.seed(seed)
        torch.manual_seed(seed)
    
        if init_fn is not None:
            init_fn(worker_id)
    
        watchdog = ManagerWatchdog()
    
        while True:
            try:
                r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                if watchdog.is_alive():
                    continue
                else:
                    break
            if r is None:
                break
            idx, batch_indices = r
            try:
                samples = collate_fn([dataset[i] for i in batch_indices])
            except Exception:
                data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
            else:
                data_queue.put((idx, samples))
                del samples

    - 需要对队列操作,缓存数据,使得加载提速!

  • 相关阅读:
    How To Compile Qt with Visual Studio 2010
    VCL线程的同步方法 Synchronize(用消息来同步)
    Delphi中怎么结束线程(这个线程是定时执行的)(方案二)
    编程之美 寻找数组中的最大值和最小值
    Delphi中怎么结束线程(这个线程是定时执行的)(方案一)
    Delphi线程同步(临界区、互斥、信号量,包括详细代码)
    Delphi管理多线程之线程局部存储:threadvar
    Delphi之通过代码示例学习XML解析、StringReplace的用法(异常控制 good)
    Delphi的文件操作(定义,关联,打开,读写,关闭)
    Android 中单位讲解
  • 原文地址:https://www.cnblogs.com/ranjiewen/p/10128046.html
Copyright © 2011-2022 走看看