zoukankan      html  css  js  c++  java
  • pytorch的dataset与dataloader解析

    整理一下pytorch获取的流程:

    1. 创建Dataset对象
    2. 创建DataLoader对象,装载有dataset对象
    3. 循环DataLoader对象,DataLoader.__iter__返回的是DataLoaderIter对象
    dataset = MyDataset()
    dataloader = DataLoader(dataset)
    num_epoches = 100
    for epoch in range(num_epoches):
        for data in dataloader:
            ....

    根据源码分析:torch.utils.data

    1 - Dataset:

    class Dataset(object):
        """An abstract class representing a Dataset.
        All other datasets should subclass it. All subclasses should override
        ``__len__``, that provides the size of the dataset, and ``__getitem__``,
        supporting integer indexing in range from 0 to len(self) exclusive.
        """
    
        def __getitem__(self, index):
            raise NotImplementedError
    
        def __len__(self):
            raise NotImplementedError
    
        def __add__(self, other):
            return ConcatDataset([self, other])
    

    Dataset这是一个抽象类,不能实例化,需要重写类方法,关键点有两个:

    • __getitem__ 这个很重要,规定了如何读数据,比如常用的transform
    • __len__ 这个就是返回数据集的长度,比如:return len(self.data)

    2 - DataLoader:

    DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
               batch_sampler=None, num_workers=0, collate_fn=None,
               pin_memory=False, drop_last=False, timeout=0,
               worker_init_fn=None)

    先看一下主要参数:

    • dataset:就是 torch.utils.data.Dataset 类的实例。也就是说为了使用 DataLoader 类,需要先定义一个 torch.utils.data.Dataset 类的实例。
    • batch_size:每一个批次需要加载的训练样本个数。
    • shuffle:如果设置为 True 表示训练样本数据会被随机打乱,默认值为 False。一般会设置为 True 。
    • sampler:自定义从数据集中取样本的策略,如果指定这个参数,那么 shuffle 必须为 False 。从源码中可以看到,如果指定了该参数,同时 shuffle 设定为 True,DataLoader 的 __init__ 函数就会抛出一个异常 。
    • batch_sampler:与 sampler 类似,但是一次只返回一个 batch 的 indices(索引),需要注意的是,一旦指定了这个参数,那么 batch_size,shuffle,sampler,drop_last 就不能再指定了。源码中同样做了限制。
    • num_workers:表示会使用多少个线程来加载训练数据;默认值为 0,表示数据加载直接在主线程中进行。
    • collate_fn:对每一个 batch 的数据做一些你想要的操作。一个例子,https://zhuanlan.zhihu.com/p/346332974
    • pin_memory:把数据转移到和 GPU 相关联的 CPU 内存,加速 GPU 载入数据的速度。
    • drop_last:比如你的batch_size设置为 32,而一个 epoch 只有 100 个样本;如果设置为 True,那么训练的时候后面的 4 个就被扔掉了。如果为 False(默认),那么会继续正常执行,只是最后的 batch_size 会小一点。
    • timeout:加载一个 batch 数据的超时时间。
    • worker_init_fn:指定每个数据加载线程的入口函数。

    源码分析:

    class DataLoader(object):
        def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, 
                     batch_sampler=None,
                     num_workers=0, collate_fn=default_collate, pin_memory=False, 
                     drop_last=False):
            self.dataset = dataset
            self.batch_size = batch_size
            self.num_workers = num_workers
            self.collate_fn = collate_fn
            self.pin_memory = pin_memory
            self.drop_last = drop_last
    
            if batch_sampler is not None:
                if batch_size > 1 or shuffle or sampler is not None or drop_last:
                    raise ValueError('batch_sampler is mutually exclusive with '
                                     'batch_size, shuffle, sampler, and drop_last')
    
            if sampler is not None and shuffle:
                raise ValueError('sampler is mutually exclusive with shuffle')
    
            if batch_sampler is None:
                if sampler is None:
                    if shuffle:
                        # dataset.__len__() 在 Sampler 中被使用。
                        # 目的是生成一个 长度为 len(dataset) 的 序列索引(随机的)。
                        sampler = RandomSampler(dataset)
                    else:
                        # dataset.__len__() 在 Sampler 中被使用。
                        # 目的是生成一个 长度为 len(dataset) 的 序列索引(顺序的)。
                        sampler = SequentialSampler(dataset)
                # Sampler 是个迭代器,一次之只返回一个 索引
                # BatchSampler 也是个迭代器,但是一次返回 batch_size 个 索引
                batch_sampler = BatchSampler(sampler, batch_size, drop_last)
    
            self.sampler = sampler
            self.batch_sampler = batch_sampler
    
        def __iter__(self):
            return DataLoaderIter(self)
    
        def __len__(self):
            return len(self.batch_sampler) 

    可以发现__iter__返回的是DataLoaderIter

    3 - DataLoaderIter

    先看init初始化:

    if self.num_workers > 0:
        self.worker_init_fn = loader.worker_init_fn
    # 定义了workers相同数量个Queue并放置在index_queues这个list中, # 这些Queue与worker一一对应,用来给worker传递“工作内容” self.index_queues = [multiprocessing.Queue() for _ in range(self.num_workers)]
    # worker_queue_idx用于下一个工作的workre序号,主进程轮询使用不同workers self.worker_queue_idx = 0
    # 各个workre将自己所取得的数据传递给wokrker_result_queue,供主进程fetch self.worker_result_queue = multiprocessing.SimpleQueue() # 记录当前时刻分配了多少个任务(可能有处于等待状态的任务) self.batches_outstanding = 0 self.worker_pids_set = False self.shutdown = False # 发送出去数据的编号 self.send_idx = 0 # 接受到数据的编号 self.rcvd_idx = 0 # 缓存区 self.reorder_dict = {} self.workers = [ multiprocessing.Process( target=_worker_loop, args=(self.dataset, self.index_queues[i], self.worker_result_queue, self.collate_fn, base_seed + i, self.worker_init_fn, i)) for i in range(self.num_workers)] # 初始化相应的进程,目标函数为_worker_loop # 参数:dataset(用于数据读取),index_queues[i]为worker对应的index_queue # 以及用于输出的queue # 此处主要用于数据读取后的pin_memory操作,不影响多进程主逻辑,暂不展开 if self.pin_memory or self.timeout > 0: ... else: self.data_queue = self.worker_result_queue for w in self.workers: w.daemon = True # ensure that the worker exits on process exit # 将父进程设置为守护进程,保证父进程结束后,worker进程也结束,必须设置在start之前 w.start() # 下面是一些系统信号处理逻辑,对这方面我还不太熟悉就不介绍了。 _update_worker_pids(id(self), tuple(w.pid for w in self.workers)) _set_SIGCHLD_handler() self.worker_pids_set = True # 初始化后生成2*num_workers数量个prefetch的数据,使dataloader提前工作,提升整体效率。 # prime the prefetch loop for _ in range(2 * self.num_workers): self._put_indices()

    init过程有两个函数,一个是worker_loop,另个是put_indices

    a. 先看worker_loop:

    def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
        global _use_shared_memory
        _use_shared_memory = True
    
        # Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
        # module's handlers are executed after Python returns from C low-level
        # handlers, likely when the same fatal signal happened again already.
        # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
        _set_worker_signal_handlers()
    
        torch.set_num_threads(1)
        random.seed(seed)
        torch.manual_seed(seed)
    
        if init_fn is not None:
            init_fn(worker_id)
        
        # 父进程状态监测
        watchdog = ManagerWatchdog()
        
        # 死循环查询是否有任务传进来
        while True:
            try:
                # 从index_queue获取相应数据
                r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                if watchdog.is_alive():
                    continue
                else:
                    break
            if r is None:
                break
            idx, batch_indices = r
            try:
                # 获得以后for循环进行读取数据读取,此处和单进程的工作原理是一样的
                # 因此时间花费和batchsize数量呈线性关系
                samples = collate_fn([dataset[i] for i in batch_indices])
                # 经过collate_fn后变成torch.Tensor
            except Exception:
                # 异常处理
                data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
            else:
                # 通过data_queue传回处理好的batch数据
                data_queue.put((idx, samples))
                # 显示删除中间变量,降低内存消耗
                del samples

    这里就是不停地轮询,从index_queues队列里获得索引,然后通过collate_fn函数和索引获取tensor,然后塞入data_queue

    b. 再看put_indices

    def _put_indices(self):
        assert self.batches_outstanding < 2 * self.num_workers
        # 默认设定是只允许分配2*num_workers个任务,保证内存等资源不被耗尽
        indices = next(self.sample_iter, None)
        # 从sample_iter中拿到dataset中下一轮次的索引,用于fetch数据
        if indices is None:
            return
        self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
        # 轮询选择worker,找到其对应的队列,向其中发送工作内容(数据编号,数据索引)
        self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
        # worker_queue_idx自增
        self.batches_outstanding += 1
        # 任务分配数+1
        self.send_idx += 1
        # 已发送任务总数+1(下批数据编号) 

    这个就是把索引塞进队列index_queues

    以上就是init,当for循环时,会调用next:

    c. __next__返回一个batch

    def __next__(self):
        if self.num_workers == 0:  # same-process loading  (主进程阻塞式读取数据)
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = pin_memory_batch(batch)
            return batch
        
        # check if the next sample has already been generated
        # 先查看数据是否在缓存dict中
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)
        # 异常处理
        if self.batches_outstanding == 0:
            self._shutdown_workers()
            raise StopIteration
        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            # 阻塞式的从data_queue里面获取处理好的批数据
            idx, batch = self._get_batch() 
            # 任务数减一
            self.batches_outstanding -= 1
            # 这一步可能会造成的周期阻塞现象
            # 每次获取data以后,要校验和rcvd_idx是否一致
            # 若不一致,则先把获取到的数据放到reorder_dict这个缓存dict中,继续死循环
            # 直到获取到相应的idx编号于rcvd_idx可以对应上,并将数据返回
            if idx != self.rcvd_idx:
                # store out-of-order samples
                self.reorder_dict[idx] = batch
                continue
            return self._process_next_batch(batch)

    __next__里的while True,要从data_queue里面读到的数据idx和rcvd_idx一致才将数据返回。因此可能会存在如下这种情况:

    假设num_workers=8,现在发送了8个数据给相应的worker,此时send_idx=8,rcvd_idx=0。过了一段时间以后,{1,2,3,5,6,7}进程数据准备完毕,此时主进程从data_queue读取到相关的数据,但由于和rcvd_idx不匹配,只能将其放在缓存里。直到send_idx=0数据准备齐以后,才能将数据返回出去,随后从缓存中弹出2,3的数据,之后又阻塞等待idx=4的数据。即输出的数据必须保持顺序性!因此在worker变多,出现这种逆序现象可能性会更大,这种现象也会出现在非num_workrers次迭代,只要相应的rcvd_idx没有得到相关数据,则主进程就会一直等待。

    d. process_next_batch

    def _process_next_batch(self, batch):
        # 序号对上以后,rcvd_idx自加1
        self.rcvd_idx += 1
        # 添加一个fetchdata任务给worker
        self._put_indices()
        if isinstance(batch, ExceptionWrapper):
            raise batch.exc_type(batch.exc_msg)
        return batch
    

      

    这个函数注意的是,只有在__next__中,idx == self.rcvd_idx时才会调用,也就是可能出现多个worker已经准备好了,但是只能放在缓存区,并且无法向index_queues塞入索引,使worker无法保持活跃状态。

    最后对于for循环从dataloader获取data总体流程:

    for epoch in range(num_epoches):
        for data in dataloader:

    对于这个for,其实就是调用了dataloader 的__iter__() 方法, 产生了一个DataLoaderIter,如果是num_worker>0,init里就会创建多线程,并且有两个队列,一个是存放dataset的索引index_queues,一个是从index_queues里拿到索引,调用dataset的__getitem__()方法 (如果num_worker>0就多线程调用), 然后用collate_fn来把它们打包成batch,放到data_queue队列里,反复调用DataLoaderIter 的__next__,从data_queue中获取batch。

    参考:

    Pytorch数据读取(Dataset, DataLoader, DataLoaderIter) https://zhuanlan.zhihu.com/p/30934236 

    PyTorch 之 Dataset 和 Dataloader https://zhuanlan.zhihu.com/p/339675188

    PyTorch36.DataLoader源代码剖析 https://zhuanlan.zhihu.com/p/169497395

    PyTorch DataLoader初探 https://zhuanlan.zhihu.com/p/91521705

    一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系 https://zhuanlan.zhihu.com/p/76893455

  • 相关阅读:
    windows下搭建hadoopproject(一)
    inspect模块---检查活动对象
    Python的datetime模块分析
    深入理解python之self
    request payload
    计算机基础知识
    pycharm常用快捷键
    英语学习五大法则
    基础语法
    英语基本语法
  • 原文地址:https://www.cnblogs.com/philo-zhou/p/14956459.html
Copyright © 2011-2022 走看看