zoukankan      html  css  js  c++  java
  • PyTorch源码解读之torch.utils.data.DataLoader(转)

    原文链接

    https://blog.csdn.net/u014380165/article/details/79058479

    写得特别好!最近正好在学习pytorch,学习一下!

    PyTorch中数据读取的一个重要接口是torch.utils.data.DataLoader,该接口定义在dataloader.py脚本中,只要是用PyTorch来训练模型基本都会用到该接口,该接口主要用来将自定义的数据读取接口的输出或者PyTorch已有的数据读取接口的输入按照batch size封装成Tensor,后续只需要再包装成Variable即可作为模型的输入,因此该接口有点承上启下的作用,比较重要。这篇博客介绍该接口的源码,主要包含DataLoader和DataLoaderIter两个类
    dataloader.py脚本的的github地址https://github.com/pytorch/pytorch/blob/master/torch/utils/data/dataloader.py

    DataLoader类源码如下。先看看__init__中的几个重要的输入:1、dataset,这个就是PyTorch已有的数据读取接口(比如torchvision.datasets.ImageFolder)或者自定义的数据接口的输出,该输出要么是torch.utils.data.Dataset类的对象,要么是继承自torch.utils.data.Dataset类的自定义类的对象。2、batch_size,根据具体情况设置即可。3、shuffle,一般在训练数据中会采用。4、collate_fn,是用来处理不同情况下的输入dataset的封装,一般采用默认即可,除非你自定义的数据读取输出非常少见。5、batch_sampler,从注释可以看出,其和batch_size、shuffle等参数是互斥的,一般采用默认。6、sampler,从代码可以看出,其和shuffle是互斥的,一般默认即可。7、num_workers,从注释可以看出这个参数必须大于等于0,0的话表示数据导入在主进程中进行,其他大于0的数表示通过多个进程来导入数据,可以加快数据导入速度。8、pin_memory,注释写得很清楚了: pin_memory (bool, optional): If True, the data loader will copy tensors into CUDA pinned memory before returning them. 也就是一个数据拷贝的问题。9、timeout,是用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。
    __init__中,RandomSampler类表示随机采样且不重复,所以起到的就是shuffle的作用。BatchSampler类则是把batch size个RandomSampler类对象封装成一个,这样就实现了随机选取一个batch的目的。这两个采样类都是定义在sampler.py脚本中,地址:https://github.com/pytorch/pytorch/blob/master/torch/utils/data/sampler.py。以上这些都是初始化的时候进行的。当代码运行到要从torch.utils.data.DataLoader类生成的对象中取数据的时候,比如:
    train_data=torch.utils.data.DataLoader(...)
    for i, (input, target) in enumerate(train_data):
    ...
    就会调用DataLoader类的__iter__方法,__iter__方法就一行代码:return DataLoaderIter(self),输入正是DataLoader类的属性。因此当调用__iter__方法的时候就牵扯到另外一个类:DataLoaderIter,接下来介绍。

    class DataLoader(object):
    """
        Data loader. Combines a dataset and a sampler, and provides
        single- or multi-process iterators over the dataset.
    
        Arguments:
            dataset (Dataset): dataset from which to load the data.
            batch_size (int, optional): how many samples per batch to load
                (default: 1).
            shuffle (bool, optional): set to ``True`` to have the data reshuffled
                at every epoch (default: False).
            sampler (Sampler, optional): defines the strategy to draw samples from
                the dataset. If specified, ``shuffle`` must be False.
            batch_sampler (Sampler, optional): like sampler, but returns a batch of
                indices at a time. Mutually exclusive with batch_size, shuffle,
                sampler, and drop_last.
            num_workers (int, optional): how many subprocesses to use for data
                loading. 0 means that the data will be loaded in the main process.
                (default: 0)
            collate_fn (callable, optional): merges a list of samples to form a mini-batch.
            pin_memory (bool, optional): If ``True``, the data loader will copy tensors
                into CUDA pinned memory before returning them.
            drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
                if the dataset size is not divisible by the batch size. If ``False`` and
                the size of dataset is not divisible by the batch size, then the last batch
                will be smaller. (default: False)
            timeout (numeric, optional): if positive, the timeout value for collecting a batch
                from workers. Should always be non-negative. (default: 0)
            worker_init_fn (callable, optional): If not None, this will be called on each
                worker subprocess with the worker id as input, after seeding and before data
                loading. (default: None)
    """
    
        def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                     num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
                     timeout=0, worker_init_fn=None):
            self.dataset = dataset
            self.batch_size = batch_size
            self.num_workers = num_workers
            self.collate_fn = collate_fn
            self.pin_memory = pin_memory
            self.drop_last = drop_last
            self.timeout = timeout
            self.worker_init_fn = worker_init_fn
    
            if timeout < 0:
                raise ValueError('timeout option should be non-negative')
    
            if batch_sampler is not None:
                if batch_size > 1 or shuffle or sampler is not None or drop_last:
                    raise ValueError('batch_sampler is mutually exclusive with '
                                     'batch_size, shuffle, sampler, and drop_last')
    
            if sampler is not None and shuffle:
                raise ValueError('sampler is mutually exclusive with shuffle')
    
            if self.num_workers < 0:
                raise ValueError('num_workers cannot be negative; '
                                 'use num_workers=0 to disable multiprocessing.')
    
            if batch_sampler is None:
                if sampler is None:
                    if shuffle:
                        sampler = RandomSampler(dataset)
                    else:
                        sampler = SequentialSampler(dataset)
                batch_sampler = BatchSampler(sampler, batch_size, drop_last)
    
            self.sampler = sampler
            self.batch_sampler = batch_sampler
    
        def __iter__(self):
            return DataLoaderIter(self)
    
        def __len__(self):
            return len(self.batch_sampler)

    DataLoaderIter类源码如下。self.index_queue = multiprocessing.SimpleQueue()中的multiprocessing是Python中的多进程管理包,而threading则是Python中的多线程管理包,二者很大一部分的接口用法类似。还是照例先看看__init__,前面部分都是一些赋值操作,比较特殊的是self.sample_iter = iter(self.batch_sampler),得到的self.sample_iter可以通过next(self.sample_iter)来获取batch size个数据的index。self.rcvd_idx表示读取到的一个batch数据的index,初始化为0,该值在迭代读取数据的时候会用到。if self.num_workers语句是针对多进程或单进程的情况进行初始化,如果不是设置为多进程读取数据,那么就不需要这些初始化操作,后面会介绍单进程数据读取。在if语句中通过multiprocessing.SimpleQueue()类创建了一个简单的队列对象。multiprocessing.Process类就是构造进程的类,这里根据设定的进程数来启动,然后赋值给self.workers。接下来的一个for循环就通过调用start方法依次启动self.workers中的进程。接下来关于self.pin_memory的判断语句,该判断语句内部主要是实现了多线程操作。self.pin_memory的含义在前面已经介绍过了,当为True的时候,就会把数据拷到CUDA中。self.data_queue = queue.Queue()是通过Python的queue模块初始化得到一个先进先出的队列(queue模块也可以初始化得到先进后出的队列,需要用queue.LifoQueue()初始化),queue模块主要应用在多线程读取数据中。在threading.Thread的args参数中,第一个参数in_data就是一个进程的数据,一个进程中不同线程的数据也是通过队列来维护的,这里采用的是Python的queue模块来初始化得到一个队列:queue.Queue()。初始化结束后,就会调用__next__方法,接下来介绍。
    总的来说,如果设置为多进程读取数据,那么就会采用队列的方式来读,如果不是采用多进程来读取数据,那就采用普通方式来读

    class DataLoaderIter(object):
        "Iterates once over the DataLoader's dataset, as specified by the sampler"
    
        def __init__(self, loader):
            self.dataset = loader.dataset
            self.collate_fn = loader.collate_fn
            self.batch_sampler = loader.batch_sampler
            self.num_workers = loader.num_workers
            self.pin_memory = loader.pin_memory and torch.cuda.is_available()
            self.timeout = loader.timeout
            self.done_event = threading.Event()
    
            self.sample_iter = iter(self.batch_sampler)
    
            if self.num_workers > 0:
                self.worker_init_fn = loader.worker_init_fn
                self.index_queue = multiprocessing.SimpleQueue()
                self.worker_result_queue = multiprocessing.SimpleQueue()
                self.batches_outstanding = 0
                self.worker_pids_set = False
                self.shutdown = False
                self.send_idx = 0
                self.rcvd_idx = 0
                self.reorder_dict = {}
    
                base_seed = torch.LongTensor(1).random_()[0]
                self.workers = [
                    multiprocessing.Process(
                        target=_worker_loop,
                        args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn,
                              base_seed + i, self.worker_init_fn, i))
                    for i in range(self.num_workers)]
    
                if self.pin_memory or self.timeout > 0:
                    self.data_queue = queue.Queue()
                    self.worker_manager_thread = threading.Thread(
                        target=_worker_manager_loop,
                        args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
                              torch.cuda.current_device()))
                    self.worker_manager_thread.daemon = True
                    self.worker_manager_thread.start()
                else:
                    self.data_queue = self.worker_result_queue
    
                for w in self.workers:
                    w.daemon = True  # ensure that the worker exits on process exit
                    w.start()
    
                _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
                _set_SIGCHLD_handler()
                self.worker_pids_set = True
    
                # prime the prefetch loop
                for _ in range(2 * self.num_workers):
                    self._put_indices()

    DataLoaderIter类的__next__方法如下,包含3个if语句和1个while语句。
    第一个if语句是用来处理self.num_workers等于0的情况,也就是不采用多进程进行数据读取,可以看出在这个if语句中先通过indices = next(self.sample_iter)获取长度为batch size的列表:indices,这个列表的每个值表示一个batch中每个数据的index,每执行一次next操作都会读取一批长度为batch size的indices列表。然后通过self.collate_fn函数将batch size个tuple(每个tuple长度为2,其中第一个值是数据,Tensor类型,第二个值是标签,int类型)封装成一个list,这个list长度为2,两个值都是Tensor,一个是batch size个数据组成的FloatTensor,另一个是batch size个标签组成的LongTensor。所以简单讲self.collate_fn函数就是将batch size个分散的Tensor封装成一个Tensor。batch = pin_memory_batch(batch)中pin_memory_batch函数的作用就是将输入batch的每个Tensor都拷贝到CUDA中,该函数后面会详细介绍。
    第二个if语句是判断当前想要读取的batch的index(self.rcvd_idx)是否之前已经读出来过(已读出来的index和batch数据保存在self.reorder_dict字典中,可以结合最后的while语句一起看,因为self.reorder_dict字典的更新是在最后的while语句中),如果之前已经读取过了,就根据这个index从reorder_dict字典中弹出对应的数据。最后返回batch数据的时候是 return self._process_next_batch(batch),该方法后面会详细介绍。主要做是获取下一个batch的数据index信息。
    第三个if语句,self.batches_outstanding的值在前面初始中调用self._put_indices()方法时修改了,所以假设你的进程数self.num_workers设置为3,那么这里self.batches_outstanding就是3*2=6,可具体看self._put_indices()方法。
    最后的while循环就是真正用来从队列中读取数据的操作,最主要的就是idx, batch = self._get_batch(),通过调用_get_batch()方法来读取,后面有介绍,简单讲就是调用了队列的get方法得到下一个batch的数据,得到的batch一般是长度为2的列表,列表的两个值都是Tensor,分别表示数据(是一个batch的)和标签。_get_batch()方法除了返回batch数据外,还得到另一个输出:idx,这个输出表示batch的index,这个if idx != self.rcvd_idx条件语句表示如果你读取到的batch的index不等于当前想要的index:selg,rcvd_idx,那么就将读取到的数据保存在字典self.reorder_dict中:self.reorder_dict[idx] = batch,然后继续读取数据,直到读取到的数据的index等于self.rcvd_idx。

        def __next__(self):
            if self.num_workers == 0:  # same-process loading
                indices = next(self.sample_iter)  # may raise StopIteration
                batch = self.collate_fn([self.dataset[i] for i in indices])
                if self.pin_memory:
                    batch = pin_memory_batch(batch)
                return batch
    
            # check if the next sample has already been generated
            if self.rcvd_idx in self.reorder_dict:
                batch = self.reorder_dict.pop(self.rcvd_idx)
                return self._process_next_batch(batch)
    
            if self.batches_outstanding == 0:
                self._shutdown_workers()
                raise StopIteration
    
            while True:
                assert (not self.shutdown and self.batches_outstanding > 0)
                idx, batch = self._get_batch()
                self.batches_outstanding -= 1
                if idx != self.rcvd_idx:
                    # store out-of-order samples
                    self.reorder_dict[idx] = batch
                    continue
                return self._process_next_batch(batch)

    pin_memory_batch函数不是定义在DataLoader类或DataLoaderIter类中。该函数主要是对batch中的Tensor执行batch.pin_memory()操作,这里的很多条件语句只是用来判断batch的类型,假如batch是一个列表,列表中的每个值是Tensor,那么就会执行 elif isinstance(batch, collections.Sequence):这个条件,从而遍历该列表中的每个Tensor,然后执行第一个条件语句的内容: return batch.pin_memory()

    def pin_memory_batch(batch):
        if torch.is_tensor(batch):
            return batch.pin_memory()
        elif isinstance(batch, string_classes):
            return batch
        elif isinstance(batch, collections.Mapping):
            return {k: pin_memory_batch(sample) for k, sample in batch.items()}
        elif isinstance(batch, collections.Sequence):
            return [pin_memory_batch(sample) for sample in batch]
        else:
            return batch
    iter() 与 next()的调用方法
    dataiter = iter(trainloader) # 每次迭代取的是一个batch images, labels = dataiter.next() # 如果batch_size为4,则取出来的images是4×c×h×w的tensor,labels是1×4的向量

    DataloaderIter类的_get_batch方法。主要根据是否设置了超时时间来操作,如果超过指定的超时时间后没有从队列中读到数据就报错,如果不设置超时时间且一致没有从队列中读到数据,那么就会一直卡着且不报错,这部分是PyTorch后来修的一个bug。

        def _get_batch(self):
            if self.timeout > 0:
                try:
                    return self.data_queue.get(True, self.timeout)
                except queue.Empty:
                    raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
            else:
                return self.data_queue.get()

    DataLoaderIter类的_process_next_batch方法。首先对self.rcvd_idx进行加一,也就是更新下下一个要读取的batch数据的index。然后调用_put_indices()方法获取下一个batch的每个数据的index。

       def _process_next_batch(self, batch):
            self.rcvd_idx += 1
            self._put_indices()
            if isinstance(batch, ExceptionWrapper):
                raise batch.exc_type(batch.exc_msg)
            return batch

    DataLoaderIter类的_put_indices方法。该方法主要实现从self.sample_iter中读取下一个batch数据中每个数据的index:indices = next(self.sample_iter, None),注意这里的index和前面idx是不一样的,这里的index是一个batch中每个数据的index,idx是一个batch的index;然后将读取到的index通过调用queue对象的put方法压到队列self.index_queue中:self.index_queue.put((self.send_idx, indices))

    def _put_indices(self):
            assert self.batches_outstanding < 2 * self.num_workers
            indices = next(self.sample_iter, None)
            if indices is None:
                return
            self.index_queue.put((self.send_idx, indices))
            self.batches_outstanding += 1
            self.send_idx += 1
  • 相关阅读:
    fastx_toolkit软件使用说明
    转:bwa的使用方法
    转: Annovar 软件注释流程介绍
    转:linux下bwa和samtools的安装与使用
    GFF3格式
    漂浮广告窗
    CPU指令集设计RISC和CISC
    南方
    成都 夹三品
    微嵌1
  • 原文地址:https://www.cnblogs.com/kk17/p/9949387.html
Copyright © 2011-2022 走看看