zoukankan      html  css  js  c++  java
  • Torch的Dataloader类源代码以及简单解析

    Torch的Dataloader类
    import torch
    import torch.multiprocessing as multiprocessing
    from . import SequentialSampler, RandomSampler, BatchSampler
    from . import _utils
    import threading
    from torch._six import queue
    
    
    # Module-level alias for the default batch-building function; it is the
    # default value of DataLoader's `collate_fn` argument.
    default_collate = _utils.collate.default_collate
    
    class DataLoader(object):
        """Wrap a dataset together with a batching strategy.

        There are two ways to configure batching:

        * Pass ``batch_sampler`` yourself. ``batch_size``, ``shuffle``,
          ``sampler`` and ``drop_last`` must then keep their defaults, and
          ``batch_size``/``drop_last`` are stored as ``None``.
        * Let the loader build a ``BatchSampler`` from ``sampler``,
          ``batch_size`` and ``drop_last``. If no ``sampler`` is given, it uses
          a ``RandomSampler`` when ``shuffle`` is true, otherwise a
          ``SequentialSampler``.

        Once construction finishes, ``batch_size``, ``sampler`` and
        ``drop_last`` are frozen: assigning any of them raises ``ValueError``.
        """

        # Name-mangled to _DataLoader__initialized; set to True at the end of __init__.
        __initialized = False

        def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                     batch_sampler=None, num_workers=0, collate_fn=default_collate,
                     pin_memory=False, drop_last=False, timeout=0,
                     worker_init_fn=None):
            """Store the options, validate them and build the batch sampler.

            Raises:
                ValueError: for a negative ``timeout`` or ``num_workers``, or for
                    a mutually exclusive combination of options.
            """
            # Record the raw options before any validation.
            self.dataset = dataset
            self.batch_size = batch_size
            self.num_workers = num_workers
            self.collate_fn = collate_fn
            self.pin_memory = pin_memory
            self.drop_last = drop_last
            self.timeout = timeout
            self.worker_init_fn = worker_init_fn

            if timeout < 0:
                raise ValueError('timeout option should be non-negative')

            has_custom_batches = batch_sampler is not None
            if has_custom_batches:
                # A caller-supplied batch_sampler does all batching itself, so any
                # option that would configure a default BatchSampler conflicts.
                conflicts = (batch_size > 1 or shuffle
                             or sampler is not None or drop_last)
                if conflicts:
                    raise ValueError('batch_sampler option is mutually exclusive '
                                     'with batch_size, shuffle, sampler, and drop_last')
                self.batch_size = None
                self.drop_last = None

            if sampler is not None and shuffle:
                raise ValueError('sampler option is mutually exclusive with '
                                 'shuffle')

            if self.num_workers < 0:
                raise ValueError('num_workers option cannot be negative; '
                                 'use num_workers=0 to disable multiprocessing.')

            if not has_custom_batches:
                if sampler is None:
                    # The shuffle flag selects the default sampler class.
                    default_sampler_cls = RandomSampler if shuffle else SequentialSampler
                    sampler = default_sampler_cls(dataset)
                batch_sampler = BatchSampler(sampler, batch_size, drop_last)

            self.sampler = sampler
            self.batch_sampler = batch_sampler
            # From here on __setattr__ refuses to change the frozen attributes.
            self.__initialized = True

        def __setattr__(self, attr, val):
            """Reject reassignment of batch_size/sampler/drop_last after __init__."""
            frozen = self.__initialized and attr in ('batch_size', 'sampler', 'drop_last')
            if frozen:
                raise ValueError('{} attribute should not be set after {} is '
                                 'initialized'.format(attr, self.__class__.__name__))
            super(DataLoader, self).__setattr__(attr, val)

        def __iter__(self):
            """Start a fresh pass over the dataset."""
            return _DataLoaderIter(self)

        def __len__(self):
            """Number of batches, as reported by the batch sampler."""
            return len(self.batch_sampler)
    

    使用方法大致如下:

    for i, (input, target) in enumerate(train_data):
    

    其中最重要的是 _DataLoaderIter 这个类。

    简单来讲,以下几点比较重要,或者说不太容易理解:

    1. 同时实现了 `__iter__()` 和 `__next__()` 的类就是迭代器。`DataLoader.__iter__()` 每次返回一个新的 `_DataLoaderIter` 迭代器对象,而 `_DataLoaderIter.__iter__()` 返回它自身。
    2. Queue 在使用时,如果队列为空,queue.get() 会阻塞;处于阻塞状态时,一旦其他进程/线程执行了 put 操作,阻塞的 get 就会被唤醒并取到数据。对于设置了 maxsize 的有界队列,队列满时 queue.put() 也会阻塞(这里创建的队列都没有指定 maxsize,因此 put 不会阻塞)。
    3. 不使用多进程(num_workers == 0)时,直接在主进程中执行 batch = self.collate_fn([self.dataset[i] for i in indices]),把索引 indices 转换为数据,例如由 (image, label) 样本组成的 batch。
    4. 使用多进程(num_workers > 0)时,为每个 worker 进程创建各自的 index_queue,所有 worker 共享同一个 worker_result_queue 结果队列,数据在 _worker_loop 中加载。
    class _DataLoaderIter(object):
        """Iterates once over the DataLoader's dataset, as specified by the sampler.

        With ``num_workers == 0``, each batch is loaded synchronously inside
        ``__next__``. Otherwise, ``num_workers`` worker processes are started.
        Each worker has its own index queue, and all of them write into one
        shared ``worker_result_queue``. Results can arrive out of order, so they
        are buffered in ``reorder_dict`` and handed out strictly in dispatch
        order.
        """
        # NOTE [ Data Loader Multiprocessing Shutdown Logic ]
        # Our data model looks like this (queues are indicated with curly brackets):
        #
        #                main process                              ||
        #                     |                                    ||
        #               {index_queue}                              ||
        #                     |                                    ||
        #              worker processes                            ||     DATA
        #                     |                                    ||
        #            {worker_result_queue}                         ||     FLOW
        #                     |                                    ||
        #      pin_memory_thread of main process                   ||   DIRECTION
        #                     |                                    ||
        #               {data_queue}                               ||
        #                     |                                    ||
        #                data output                               /
        #

        def __init__(self, loader):
            """Snapshot the loader's settings and, if requested, start the workers."""
            self.dataset = loader.dataset
            self.collate_fn = loader.collate_fn
            self.batch_sampler = loader.batch_sampler
            self.num_workers = loader.num_workers
            # Pinning is only enabled when CUDA is available.
            self.pin_memory = loader.pin_memory and torch.cuda.is_available()
            self.timeout = loader.timeout

            self.sample_iter = iter(self.batch_sampler)

            # Worker i is seeded with base_seed + i.
            base_seed = torch.LongTensor(1).random_().item()

            if self.num_workers > 0:
                self.worker_init_fn = loader.worker_init_fn
                self.worker_queue_idx = 0  # next worker to receive indices (round-robin)
                self.worker_result_queue = multiprocessing.Queue()
                self.batches_outstanding = 0  # dispatched but not yet received
                self.worker_pids_set = False
                self.shutdown = False
                self.send_idx = 0  # index tag of the next batch to dispatch
                self.rcvd_idx = 0  # index tag of the next batch to return
                self.reorder_dict = {}  # out-of-order results keyed by index tag
                self.done_event = multiprocessing.Event()

                self.index_queues = []
                self.workers = []
                # Start one process per worker, each running
                # _utils.worker._worker_loop with its own index queue and the
                # shared result queue.
                for i in range(self.num_workers):
                    index_queue = multiprocessing.Queue()
                    index_queue.cancel_join_thread()
                    w = multiprocessing.Process(
                        target=_utils.worker._worker_loop,
                        args=(self.dataset, index_queue,
                              self.worker_result_queue, self.done_event,
                              self.collate_fn, base_seed + i,
                              self.worker_init_fn, i))
                    w.daemon = True
                    w.start()
                    # Register the worker only after start() succeeds, so that
                    # _shutdown_workers never joins a process that was never started.
                    self.index_queues.append(index_queue)
                    self.workers.append(w)

                if self.pin_memory:
                    # A background thread moves results from worker_result_queue into
                    # data_queue. NOTE(review): presumably it pins each batch on the
                    # way (copies tensors into page-locked host memory for faster GPU
                    # transfer); see _utils.pin_memory._pin_memory_loop.
                    self.data_queue = queue.Queue()
                    pin_memory_thread = threading.Thread(
                        target=_utils.pin_memory._pin_memory_loop,
                        args=(self.worker_result_queue, self.data_queue,
                              torch.cuda.current_device(), self.done_event))
                    pin_memory_thread.daemon = True
                    pin_memory_thread.start()
                    # Similar to workers, register pin_memory_thread only once it has
                    # started (_shutdown_workers checks for this attribute).
                    self.pin_memory_thread = pin_memory_thread
                else:
                    self.data_queue = self.worker_result_queue

                # Register the worker pids under id(self) and install a SIGCHLD
                # handler. NOTE(review): presumably this lets the main process
                # notice workers that die unexpectedly; see _utils.signal_handling.
                _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self.workers))
                _utils.signal_handling._set_SIGCHLD_handler()
                self.worker_pids_set = True

                # Prime the prefetch pipeline with up to two batches per worker.
                # _put_indices never allows more than 2 * num_workers batches in
                # flight, and every returned batch dispatches one more, so workers
                # stay busy without reading arbitrarily far ahead.
                for _ in range(2 * self.num_workers):
                    self._put_indices()

        def __len__(self):
            """Number of batches, as reported by the batch sampler."""
            return len(self.batch_sampler)

        def _get_batch(self):
            """Fetch the next ``(idx, batch)`` pair from ``data_queue``.

            Raises:
                RuntimeError: if ``timeout > 0`` and nothing arrives in time, or
                    if the pin-memory thread dies while we are waiting.
            """
            if self.timeout > 0:
                try:
                    return self.data_queue.get(timeout=self.timeout)
                except queue.Empty:
                    raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
            elif self.pin_memory:
                # Poll in short intervals so a dead pin-memory thread is detected
                # instead of blocking forever on an empty queue.
                while self.pin_memory_thread.is_alive():
                    try:
                        return self.data_queue.get(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
                    except queue.Empty:
                        continue
                else:
                    # while condition is false, i.e., pin_memory_thread died.
                    raise RuntimeError('Pin memory thread exited unexpectedly')
            else:
                return self.data_queue.get()

        def __next__(self):
            """Return the next batch, in the order the sampler produced it.

            Raises:
                StopIteration: when there is nothing left to return.
            """
            if self.num_workers == 0:  # same-process loading
                indices = next(self.sample_iter)  # may raise StopIteration
                batch = self.collate_fn([self.dataset[i] for i in indices])
                if self.pin_memory:
                    batch = _utils.pin_memory.pin_memory_batch(batch)
                return batch

            # The batch we need may already have arrived out of order earlier.
            if self.rcvd_idx in self.reorder_dict:
                batch = self.reorder_dict.pop(self.rcvd_idx)
                return self._process_next_batch(batch)

            # Nothing is in flight and the next batch is not buffered, so there is
            # nothing left to return: stop the workers and end the iteration.
            if self.batches_outstanding == 0:
                self._shutdown_workers()
                raise StopIteration

            while True:
                assert (not self.shutdown and self.batches_outstanding > 0)
                idx, batch = self._get_batch()
                self.batches_outstanding -= 1
                if idx != self.rcvd_idx:
                    # Store out-of-order samples until their turn comes.
                    self.reorder_dict[idx] = batch
                    continue
                return self._process_next_batch(batch)

        next = __next__  # Python 2 compatibility

        def __iter__(self):
            return self

        def _put_indices(self):
            """Dispatch the next batch of indices to a worker, round-robin.

            Each dispatched batch is tagged with ``send_idx`` so results can be
            put back in order on receipt. Does nothing once the sampler is
            exhausted.
            """
            assert self.batches_outstanding < 2 * self.num_workers
            indices = next(self.sample_iter, None)
            if indices is None:
                return
            self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
            self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
            self.batches_outstanding += 1
            self.send_idx += 1

        def _process_next_batch(self, batch):
            """Advance ``rcvd_idx``, refill the pipeline and return ``batch``.

            A worker-side failure arrives as an ``_utils.ExceptionWrapper``. It is
            re-raised in the main process as ``exc_type(exc_msg)``.
            """
            self.rcvd_idx += 1
            self._put_indices()  # keep the prefetch pipeline topped up
            if isinstance(batch, _utils.ExceptionWrapper):
                raise batch.exc_type(batch.exc_msg)
            return batch

        def __getstate__(self):
            # TODO: add limited pickling support for sharing an iterator across
            # multiple threads for HOGWILD. The best approach is probably to push
            # samples from a separate thread and share only the data queue, but
            # signalling the end of iteration is tricky without a non-blocking API.
            raise NotImplementedError("_DataLoaderIter cannot be pickled")

        def _shutdown_workers(self):
            """Stop the pin-memory thread and all worker processes; safe to call twice.

            See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details.
            """
            python_exit_status = _utils.python_exit_status
            if python_exit_status is True or python_exit_status is None:
                # See (2) of the note. If Python is shutting down, do no-op.
                return
            # `shutdown` is missing if __init__ raised before setting it. No worker
            # has been started in that case, so there is nothing to stop.
            if getattr(self, 'shutdown', True):
                return
            # Normal exit when last reference is gone / iterator is depleted.
            self.shutdown = True
            # Remove pids from the C side data structure first so worker
            # termination afterwards won't trigger a false-positive error report.
            if self.worker_pids_set:
                _utils.signal_handling._remove_worker_pids(id(self))
                self.worker_pids_set = False

            self.done_event.set()

            # Exit `pin_memory_thread` first because exiting workers may leave
            # corrupted data in `worker_result_queue`, which `pin_memory_thread`
            # reads from.
            if hasattr(self, 'pin_memory_thread'):
                self.worker_result_queue.cancel_join_thread()
                # NOTE(review): `None` presumably tells the pin-memory loop to stop.
                self.worker_result_queue.put(None)
                self.pin_memory_thread.join()
                self.worker_result_queue.close()

            # Exit workers now: `None` on each index queue is the stop signal.
            for q in self.index_queues:
                q.put(None)
                # No more data will be put on this queue by the current process.
                q.close()
            for w in self.workers:
                w.join()

        def __del__(self):
            # Only the multi-process path owns workers/threads that need cleanup.
            # getattr: __init__ may have raised before num_workers was assigned.
            if getattr(self, 'num_workers', 0) > 0:
                self._shutdown_workers()
    
    def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
        """Main loop of a data-loading worker process.

        Repeatedly takes ``(idx, batch_indices)`` from ``index_queue``, builds
        ``collate_fn([dataset[i] for i in batch_indices])`` and puts
        ``(idx, samples)`` on ``data_queue``. If building the batch raises, the
        exception is wrapped as ``ExceptionWrapper(sys.exc_info())`` and put on
        ``data_queue`` under the same ``idx`` instead.

        The loop ends when ``None`` is received, or when ``index_queue.get``
        times out and the watchdog no longer reports ``is_alive()``.

        NOTE(review): this signature has no ``done_event`` parameter. The
        ``_DataLoaderIter`` class in this file instead starts workers with
        ``_utils.worker._worker_loop`` and eight positional arguments (including
        ``done_event``). This function looks like an older version shown for
        illustration, not the one that class runs; confirm before relying on it.
        NOTE(review): ``ExceptionWrapper``, ``ManagerWatchdog``,
        ``MANAGER_STATUS_CHECK_INTERVAL`` and ``_set_worker_signal_handlers`` are
        not defined in the visible part of this file and must come from elsewhere.
        """
        # Not imported at the top of this file; needed for seeding and exc_info.
        import random
        import sys

        global _use_shared_memory
        # NOTE(review): presumably makes the collate function allocate tensors in
        # shared memory so batches can be handed to the main process cheaply.
        _use_shared_memory = True

        # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
        # module's handlers are executed after Python returns from C low-level
        # handlers, likely when the same fatal signal happened again already.
        # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
        _set_worker_signal_handlers()

        torch.set_num_threads(1)
        random.seed(seed)
        torch.manual_seed(seed)

        if init_fn is not None:  # user hook, called with this worker's id
            init_fn(worker_id)

        # NOTE(review): presumably reports whether the parent process is still alive.
        watchdog = ManagerWatchdog()

        while True:
            try:
                r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                # No work yet: keep waiting only while the watchdog reports alive.
                if watchdog.is_alive():
                    continue
                else:
                    break
            if r is None:  # stop signal from the main process
                break
            idx, batch_indices = r
            try:
                samples = collate_fn([dataset[i] for i in batch_indices])
            except Exception:
                data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
            else:
                # `idx` tags the batch so the main process can restore the order.
                data_queue.put((idx, samples))
                del samples
    
  • 相关阅读:
    Android 主题theme说明 摘记
    Android开发 去掉标题栏方法 摘记
    安卓项目五子棋代码详解(二)
    关于 ake sure class name exists, is public, and has an empty constructor that is public
    百度地图3.0实现图文并茂的覆盖物
    android onSaveInstanceState()及其配对方法。
    关于集成科大讯飞语音识别的 一个问题总结
    android 关于 webview 控制其它view的显示 以及更改view数据失败的问题总结
    C# 解析 json Newtonsoft果然强大,代码写的真好
    c#数据类型 与sql的对应关系 以及 取值范围
  • 原文地址:https://www.cnblogs.com/JohnRan/p/15099683.html
Copyright © 2011-2022 走看看