zoukankan      html  css  js  c++  java
  • Torch的Dataloader类源代码以及简单解析

    Torch的Dataloader类
    import torch
    import torch.multiprocessing as multiprocessing
    from . import SequentialSampler, RandomSampler, BatchSampler
    from . import _utils
    import threading
    from torch._six import queue
    
    
    # Module-level alias of _utils.collate.default_collate; it is the default
    # ``collate_fn`` argument of DataLoader below.
    default_collate = _utils.collate.default_collate
    
    class DataLoader(object):
        """Pair a dataset with a sampling strategy and iterate over it in batches.

        Args:
            dataset: indexed as ``dataset[i]`` with the indices produced by the
                sampler; also handed to the default sampler constructors.
            batch_size (int): samples per batch when no ``batch_sampler`` is given.
            shuffle (bool): build a RandomSampler instead of a SequentialSampler
                when no explicit ``sampler`` is supplied.
            sampler: yields single sample indices; cannot be combined with
                ``shuffle``.
            batch_sampler: yields whole lists of indices; cannot be combined with
                ``batch_size > 1``, ``shuffle``, ``sampler`` or ``drop_last``.
                When given, ``self.batch_size`` and ``self.drop_last`` are None.
            num_workers (int): number of worker processes; 0 loads in the
                calling process.
            collate_fn: merges a list of samples into a batch.
            pin_memory (bool): request pinned batches (the iterator only honours
                this when CUDA is available).
            drop_last (bool): forwarded to the BatchSampler built here.
            timeout (number): seconds to wait for a worker batch; 0 disables it.
            worker_init_fn: forwarded to each worker together with its index.

        Raises:
            ValueError: for a negative ``timeout`` or ``num_workers``, or for
                mutually exclusive options.

        After construction, ``batch_size``, ``sampler`` and ``drop_last`` may
        no longer be reassigned.
        """

        # Flipped to True (per instance) at the very end of __init__; until then
        # __setattr__ lets every attribute through.
        __initialized = False

        def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                     batch_sampler=None, num_workers=0, collate_fn=default_collate,
                     pin_memory=False, drop_last=False, timeout=0,
                     worker_init_fn=None):
            # Record the raw options first; batch_size/drop_last may be
            # overridden below when a batch_sampler is supplied.
            self.dataset = dataset
            self.batch_size = batch_size
            self.num_workers = num_workers
            self.collate_fn = collate_fn
            self.pin_memory = pin_memory
            self.drop_last = drop_last
            self.timeout = timeout
            self.worker_init_fn = worker_init_fn

            if timeout < 0:
                raise ValueError('timeout option should be non-negative')

            if batch_sampler is not None:
                # A custom batch_sampler fully decides batching, so any other
                # batching-related option is a conflict.
                conflicting = (batch_size > 1 or shuffle or sampler is not None
                               or drop_last)
                if conflicting:
                    raise ValueError('batch_sampler option is mutually exclusive '
                                     'with batch_size, shuffle, sampler, and drop_last')
                self.batch_size = None
                self.drop_last = None

            if sampler is not None and shuffle:
                raise ValueError('sampler option is mutually exclusive with '
                                 'shuffle')

            if self.num_workers < 0:
                raise ValueError('num_workers option cannot be negative; '
                                 'use num_workers=0 to disable multiprocessing.')

            if batch_sampler is None:
                if sampler is None:
                    sampler_cls = RandomSampler if shuffle else SequentialSampler
                    sampler = sampler_cls(dataset)
                batch_sampler = BatchSampler(sampler, batch_size, drop_last)

            self.sampler = sampler
            self.batch_sampler = batch_sampler
            # From now on the guarded options are read-only.
            self.__initialized = True

        def __setattr__(self, attr, val):
            frozen = self.__initialized and attr in ('batch_size', 'sampler', 'drop_last')
            if frozen:
                raise ValueError('{} attribute should not be set after {} is '
                                 'initialized'.format(attr, self.__class__.__name__))
            super(DataLoader, self).__setattr__(attr, val)

        def __iter__(self):
            # Every iteration gets a fresh single-pass iterator.
            return _DataLoaderIter(self)

        def __len__(self):
            # Number of batches, as reported by the batch sampler.
            return len(self.batch_sampler)
    

    使用方法大致如下:

    for i, (input, target) in enumerate(train_data):  # train_data 是一个 DataLoader 实例
        ...  # 每次迭代得到一个 batch
    

    其中最重要的是 _DataLoaderIter 这个类。

    简单地讲,以下几点比较重要,或者说不太容易理解:

    1. 实现了 __iter__() 和 __next__() 的类就是迭代器。DataLoader.__iter__() 每次返回一个新的 _DataLoaderIter 对象,而 _DataLoaderIter.__iter__() 返回它自身。
    2. 队列(Queue)为空时,queue.get() 会阻塞;阻塞期间如果其他进程/线程执行了 put 操作,阻塞的线程会被唤醒并成功取到数据。队列已满(即设置了 maxsize 且已达上限)时,queue.put() 会阻塞;本文代码中创建的队列都没有设置 maxsize,因此 put 不会阻塞。
    3. 不使用多进程(num_workers == 0)时,直接执行 batch = self.collate_fn([self.dataset[i] for i in indices]),把 index 转换成数据,例如 (image, label)。
    4. 使用多进程(num_workers > 0)时,为每个 worker 进程创建一个自己的 index_queue,所有 worker 共享同一个结果队列 worker_result_queue,数据在 _worker_loop 中加载。
    class _DataLoaderIter(object):
        """Iterates once over the DataLoader's dataset, as specified by the sampler.

        With ``num_workers == 0`` every batch is built synchronously inside
        ``__next__``.  Otherwise each list of indices from the batch sampler is
        tagged with a sequence number (``send_idx``) and dispatched round-robin
        to one worker process through that worker's own ``index_queue``.
        Results come back as ``(idx, batch)`` pairs on the shared
        ``worker_result_queue``, optionally relayed to ``data_queue`` by a
        pin-memory thread.  Results may arrive out of order; ``reorder_dict``
        buffers early arrivals so batches are returned in sampler order.
        """
        # NOTE [ Data Loader Multiprocessing Shutdown Logic ]
        # Our data model looks like this (queues are indicated with curly brackets):
        #
        #                main process                              ||
        #                     |                                    ||
        #               {index_queue}                              ||
        #                     |                                    ||
        #              worker processes                            ||     DATA
        #                     |                                    ||
        #            {worker_result_queue}                         ||     FLOW
        #                     |                                    ||
        #      pin_memory_thread of main process                   ||   DIRECTION
        #                     |                                    ||
        #               {data_queue}                               ||
        #                     |                                    ||
        #                data output                               /
        #

        def __init__(self, loader):
            """Snapshot ``loader``'s settings; with workers, start them and prefetch.

            Args:
                loader: the owning DataLoader.
            """
            self.dataset = loader.dataset
            self.collate_fn = loader.collate_fn
            self.batch_sampler = loader.batch_sampler
            self.num_workers = loader.num_workers
            # Pinning is only attempted when CUDA is available.
            self.pin_memory = loader.pin_memory and torch.cuda.is_available()
            self.timeout = loader.timeout

            self.sample_iter = iter(self.batch_sampler)

            # Worker i is seeded with base_seed + i.
            base_seed = torch.LongTensor(1).random_().item()

            if self.num_workers > 0:
                self.worker_init_fn = loader.worker_init_fn
                self.worker_queue_idx = 0       # worker that receives the next index list
                self.worker_result_queue = multiprocessing.Queue()
                self.batches_outstanding = 0    # batches sent but not yet received
                self.worker_pids_set = False
                self.shutdown = False
                self.send_idx = 0               # sequence number of the next batch sent
                self.rcvd_idx = 0               # sequence number of the next batch returned
                self.reorder_dict = {}          # early arrivals, keyed by sequence number
                self.done_event = multiprocessing.Event()

                self.index_queues = []
                self.workers = []
                for i in range(self.num_workers):  # start one process per worker
                    index_queue = multiprocessing.Queue()
                    index_queue.cancel_join_thread()
                    # Each worker process runs _utils.worker._worker_loop.  Its
                    # results are expected on worker_result_queue as (idx, batch),
                    # where idx is the send_idx assigned in _put_indices (not a
                    # sample index); __next__ uses it to restore the order.
                    w = multiprocessing.Process(
                        target=_utils.worker._worker_loop,
                        args=(self.dataset, index_queue,
                              self.worker_result_queue, self.done_event,
                              self.collate_fn, base_seed + i,
                              self.worker_init_fn, i))
                    w.daemon = True
                    # NB: a worker is appended to self.workers only after start()
                    # returns, so _shutdown_workers never join()s a process that
                    # was never started.
                    w.start()
                    self.index_queues.append(index_queue)
                    self.workers.append(w)

                if self.pin_memory:
                    # A background thread runs _utils.pin_memory._pin_memory_loop
                    # with worker_result_queue and data_queue (the queue that
                    # _get_batch reads).  Judging by its name it copies batches
                    # into pinned (page-locked) host memory -- confirm in
                    # _utils.pin_memory.
                    self.data_queue = queue.Queue()
                    pin_memory_thread = threading.Thread(
                        target=_utils.pin_memory._pin_memory_loop,
                        args=(self.worker_result_queue, self.data_queue,
                              torch.cuda.current_device(), self.done_event))
                    pin_memory_thread.daemon = True
                    pin_memory_thread.start()
                    # Similar to workers (see comment above), we only register pin_memory_thread once it is started.
                    self.pin_memory_thread = pin_memory_thread
                else:
                    # Without pinning, batches are read straight from the workers' queue.
                    self.data_queue = self.worker_result_queue

                # Register the worker pids (keyed by id(self)) and install a
                # SIGCHLD handler via _utils.signal_handling -- presumably so an
                # unexpectedly dead worker is reported instead of causing a hang
                # (TODO confirm).  _shutdown_workers removes the registration.
                _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self.workers))
                _utils.signal_handling._set_SIGCHLD_handler()
                self.worker_pids_set = True

                # Prime the prefetch loop with 2 batches per worker.  This is the
                # cap asserted in _put_indices.  Afterwards every batch handed
                # out by _process_next_batch requests exactly one more.
                for _ in range(2 * self.num_workers):
                    self._put_indices()

        def __len__(self):
            # Number of batches, as reported by the batch sampler.
            return len(self.batch_sampler)

        def _get_batch(self):
            """Return the next ``(idx, batch)`` pair from ``data_queue``.

            Raises:
                RuntimeError: if ``timeout > 0`` and nothing arrives in time, or
                    if the pin-memory thread has died.
            """
            if self.timeout > 0:
                try:
                    return self.data_queue.get(timeout=self.timeout)
                except queue.Empty:
                    raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
            elif self.pin_memory:
                # Poll in short intervals so that a dead pin-memory thread is
                # noticed instead of blocking forever.
                while self.pin_memory_thread.is_alive():
                    try:
                        return self.data_queue.get(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
                    except queue.Empty:
                        continue
                else:
                    # while condition is false, i.e., pin_memory_thread died.
                    raise RuntimeError('Pin memory thread exited unexpectedly')
            else:
                return self.data_queue.get()

        def __next__(self):
            """Return the next batch in sampler order.

            Raises:
                StopIteration: when the sampler is exhausted.  With workers, this
                    happens once every outstanding batch has been returned, and
                    the workers are shut down first.
            """
            if self.num_workers == 0:  # same-process loading
                indices = next(self.sample_iter)  # may raise StopIteration
                batch = self.collate_fn([self.dataset[i] for i in indices])
                if self.pin_memory:
                    batch = _utils.pin_memory.pin_memory_batch(batch)
                return batch

            # The batch we need may already have arrived out of order.
            if self.rcvd_idx in self.reorder_dict:
                batch = self.reorder_dict.pop(self.rcvd_idx)
                return self._process_next_batch(batch)

            # No batch in flight and the next one is not buffered: the epoch is
            # over, so shut the workers down and stop.
            if self.batches_outstanding == 0:
                self._shutdown_workers()
                raise StopIteration

            while True:
                assert (not self.shutdown and self.batches_outstanding > 0)
                idx, batch = self._get_batch()
                self.batches_outstanding -= 1
                if idx != self.rcvd_idx:
                    # Batches must be returned in rcvd_idx order; keep this
                    # out-of-order one until its turn comes.
                    self.reorder_dict[idx] = batch
                    continue
                return self._process_next_batch(batch)

        next = __next__  # Python 2 compatibility

        def __iter__(self):
            return self

        def _put_indices(self):
            """Send the sampler's next index list to the next worker (round-robin).

            The list is tagged with ``send_idx``.  Does nothing once the sampler
            is exhausted.
            """
            assert self.batches_outstanding < 2 * self.num_workers
            indices = next(self.sample_iter, None)
            if indices is None:
                return
            self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
            self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
            self.batches_outstanding += 1
            self.send_idx += 1

        def _process_next_batch(self, batch):
            """Advance ``rcvd_idx``, request one more batch, then return ``batch``.

            Raises:
                batch.exc_type: if a worker sent an ``ExceptionWrapper`` instead
                    of data.  The refill request has already been issued.
            """
            self.rcvd_idx += 1
            self._put_indices()
            if isinstance(batch, _utils.ExceptionWrapper):
                raise batch.exc_type(batch.exc_msg)
            return batch

        def __getstate__(self):
            """Refuse to be pickled.

            TODO: add limited pickling support so the iterator can be shared
            across multiple threads for HOGWILD.  The best way is probably to
            move the sample pushing into a separate thread and share only the
            data queue, but signalling the end is hard without a non-blocking
            API.
            """
            raise NotImplementedError("_DataLoaderIter cannot be pickled")

        def _shutdown_workers(self):
            """Stop the pin-memory thread and the worker processes.

            This is a no-op when ``_utils.python_exit_status`` is True or None,
            or once ``self.shutdown`` has been set, so repeated calls are safe.
            """
            python_exit_status = _utils.python_exit_status
            if python_exit_status is True or python_exit_status is None:
                # Treated as "the interpreter is shutting down": do nothing.
                return
            # Normal exit when last reference is gone / iterator is depleted.
            if not self.shutdown:
                self.shutdown = True
                # Removes pids from the C side data structure first so worker termination afterwards won't trigger false positive error report.
                if self.worker_pids_set:
                    _utils.signal_handling._remove_worker_pids(id(self))
                    self.worker_pids_set = False

                self.done_event.set()

                # Exit `pin_memory_thread` first because exiting workers may leave
                # corrupted data in `worker_result_queue` which `pin_memory_thread` reads from.
                if hasattr(self, 'pin_memory_thread'):
                    self.worker_result_queue.cancel_join_thread()
                    self.worker_result_queue.put(None)
                    self.pin_memory_thread.join()
                    self.worker_result_queue.close()

                # Exit workers now: a None on each index_queue is the stop
                # sentinel for that worker.
                for q in self.index_queues:
                    q.put(None)
                    # Indicate that no more data will be put on this queue by the current process.
                    q.close()
                for w in self.workers:
                    w.join()

        def __del__(self):
            # Only the multi-process path owns workers/threads that need cleanup.
            if self.num_workers > 0:
                self._shutdown_workers()
    
    def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
        """Serve batch requests from ``index_queue`` until told to stop.

        Each request is an ``(idx, batch_indices)`` pair.  The worker loads
        ``dataset[i]`` for every index, merges the samples with ``collate_fn``
        and replies on ``data_queue`` with ``(idx, batch)``.  If loading or
        collating raises, it replies with
        ``(idx, ExceptionWrapper(sys.exc_info()))`` instead.

        The loop ends on a ``None`` request.  It also ends when polling
        ``index_queue`` times out and the ``ManagerWatchdog`` reports it is no
        longer alive (presumably meaning the parent is gone -- confirm).

        NOTE(review): this signature has no ``done_event`` parameter, while
        _DataLoaderIter passes eight positional arguments (including
        ``done_event``) to ``_utils.worker._worker_loop``.  This copy looks like
        it comes from an older release; verify before relying on it.
        """
        global _use_shared_memory
        _use_shared_memory = True

        # Install C-level SIGBUS/SIGSEGV handlers.  Python's own signal handlers
        # only run after control returns from C, which may be after the same
        # fatal signal has already happened again.
        # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
        _set_worker_signal_handlers()

        # Keep torch single-threaded in each worker; seed both RNGs per worker.
        torch.set_num_threads(1)
        random.seed(seed)
        torch.manual_seed(seed)

        if init_fn is not None:
            init_fn(worker_id)

        watchdog = ManagerWatchdog()

        while True:
            try:
                request = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                # Nothing to do yet: keep polling only while the watchdog is alive.
                if not watchdog.is_alive():
                    break
                continue
            if request is None:
                # Stop sentinel (the main process puts None on each index queue).
                break
            idx, batch_indices = request
            try:
                samples = collate_fn([dataset[i] for i in batch_indices])
            except Exception:
                # Ship the failure back so the main process can re-raise it.
                data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
                continue
            data_queue.put((idx, samples))
            del samples
    
  • 相关阅读:
    JVM系列六(自定义插入式注解器).
    JVM系列五(Javac 字节码编译器).
    2019 — 求不得,放不下
    Mybatis 条件判断单双引号解析问题
    JVM系列四(对象分配策略).
    JVM系列三(垃圾收集器).
    Spring MVC -- Spring Tool Suite和Maven(安装Tomcat、JDK)
    Spring MVC -- 单元测试和集成测试
    Spring MVC -- 下载文件
    Spring MVC -- 上传文件
  • 原文地址:https://www.cnblogs.com/JohnRan/p/15099683.html
Copyright © 2011-2022 走看看