zoukankan      html  css  js  c++  java
  • Pytorch的DataLoader, DataSet, Sampler之间的关系

    本文参考: https://www.cnblogs.com/marsggbo/p/11308889.html

    pytorch 数据加载到模型的流程

    pytorch 的数据加载到模型的操作顺序是这样的:
    ① 创建一个 Dataset 对象
    ② 创建一个 DataLoader 对象
    ③ 循环这个 DataLoader 对象,将img, label加载到模型中进行训练

    dataset = MyDataset()  #创建一个dataset对象
    dataloader = DataLoader(dataset) #把dataset传入dataloader中得到一个dataloader对象
    num_epoches = 100
    for epoch in range(num_epoches): 
        for img, label in dataloader:
            ....
    

    DataLoader 函数及其说明

    DataLoader将自定义的Dataset根据batch size大小、是否shuffle等封装成一个Batch Size大小的Tensor,用于后面的训练。

    DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
    batch_sampler=None, num_workers=0, collate_fn=None,
    pin_memory=False, drop_last=False, timeout=0,
    worker_init_fn=None)
    • dataset 传入的数据集
    • batch_size 每个batch有多少个样本
    • shuffle 在每个epoch开始的时候,打乱数据
    • sampler
      自定义从样本中取数据的策略,如果指定这个参数,则shuffle必须为false
    • batch_sampler
      与sampler类似,但是一次只返回一个batch的indices(索引),需要注意的是,一旦指定了这个参数,那么batch_size,shuffle,sampler,drop_last就不能再指定了
    • num_workers
      这个参数决定了有几个进程来处理data loading。0意味着所有的数据都会被load进主进程。(默认为0)
    • collate_fn 将一个list的sample组成一个mini-batch的函数
    • pin_memory
      如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中
      pin_memory就是锁页内存,创建设置True,则意味着生成的Tensor数据最开始是属于内存中的锁页内存,这样将内存的Tensor转义到GPU的显存就会更快一些。
      主机中的内存,有两种存在方式,一是锁页,二是不锁页,锁页内存存放的内容在任何情况下都不会与主机的虚拟内存进行交换(注:虚拟内存就是硬盘),而不锁页内存在主机内存不足时,数据会存放在虚拟内存中。而显卡中的显存全部是锁页内存!
      当计算机的内存充足的时候,可以设置pin_memory=True。当系统卡住,或者交换内存使用过多的时候,设置pin_memory=False。因为pin_memory与电脑硬件性能有关,pytorch开发者不能确保每一个炼丹玩家都有高端设备,因此pin_memory默认为False。
    • drop_last
      如果设置为True:这个是对最后的未完成的batch来说的,比如你的batch_size设置为64,而一个epoch只有100个样本,那么训练的时候后面的36个就被扔掉了…
      如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点。
    • timeout
      如果是正数,表明等待从worker进程中收集一个batch等待的时间,若超出设定的时间还没有收集到,那就不收集这个内容了。这个numeric应总是大于等于0。默认为0

    DataLoader, DataSet, Sampler之间的关系

    DataLoader

      首先我们看一下DataLoader.next的源代码长什么样,为方便理解我只选取了num_works为0的情况(num_works简单理解就是能够并行化地读取数据)。

    class DataLoader(object):
    	...
        def __next__(self):
            if self.num_workers == 0:  
                indices = next(self.sample_iter)  # Sampler
                batch = self.collate_fn([self.dataset[i] for i in indices]) # Dataset
                if self.pin_memory:
                    batch = _utils.pin_memory.pin_memory_batch(batch)
                return batch
    

      在阅读上面代码前,我们可以假设我们的数据是一组图像,每一张图像对应一个index,那么如果我们要读取数据就只需要对应的index即可,即上面代码中的indices,而选取index的方式有多种,有按顺序的,也有乱序的,所以这个工作需要Sampler完成,现在你不需要具体的细节,后面会介绍,你只需要知道DataLoader和Sampler在这里产生关系。

    那么Dataset和DataLoader在什么时候产生关系呢?没错就是下面一行。我们已经拿到了indices,那么下一步我们只需要根据index对数据进行读取即可了。

    再下面的if语句的作用简单理解就是,如果pin_memory=True,那么Pytorch会采取一系列操作把数据拷贝到GPU,总之就是为了加速。

    综上可以知道DataLoader,Sampler和Dataset三者关系如下:

    Sampler

    要更加细致地理解Sampler原理,我们需要先阅读一下DataLoader 的源代码,如下:

    class DataLoader(object):
        def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                     batch_sampler=None, num_workers=0, collate_fn=default_collate,
                     pin_memory=False, drop_last=False, timeout=0,
                     worker_init_fn=None)

      可以看到初始化参数里有两种sampler:samplerbatch_sampler,都默认为None。前者的作用是生成一系列的index,而batch_sampler则是将sampler生成的indices打包分组,得到一个又一个batch的index。例如下面示例中,BatchSamplerSequentialSampler生成的index按照指定的batch size分组。

    >>>in : list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
    >>>out: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

    Pytorch中已经实现的Sampler有如下几种:

    • SequentialSampler
    • RandomSampler
    • WeightedSampler
    • SubsetRandomSampler

    需要注意的是DataLoader的部分初始化参数之间存在互斥关系,这个你可以通过阅读源码更深地理解,这里只做总结:

    • 如果你自定义了batch_sampler,那么这些参数都必须使用默认值:batch_sizeshuffle,sampler,drop_last.
    • 如果你自定义了sampler,那么shuffle需要设置为False
    • 如果samplerbatch_sampler都为None,那么batch_sampler使用Pytorch已经实现好的BatchSampler,而sampler分两种情况:
      • shuffle=True,则sampler=RandomSampler(dataset)
      • shuffle=False,则sampler=SequentialSampler(dataset)

    Dataset 

    Dataset定义方式如下:

    class Dataset(object):
    	def __init__(self):
    		...
    	def __getitem__(self, index):
    		return ...
    	def __len__(self):
    		return ...
    

      上面三个方法是最基本的,其中__getitem__是最主要的方法,它规定了如何读取数据。但是它又不同于一般的方法,因为它是python built-in方法,其主要作用是能让该类可以像list一样通过索引值对数据进行访问。假如你定义好了一个dataset,那么你可以直接通过dataset[0]来访问第一个数据。在此之前我一直没弄清楚__getitem__是什么作用,所以一直不知道该怎么进入到这个函数进行调试。现在如果你想对__getitem__方法进行调试,你可以写一个for循环遍历dataset来进行调试了,而不用构建dataloader等一大堆东西了,建议学会使用ipdb这个库,非常实用!!!以后有时间再写一篇ipdb的使用教程。另外,其实我们通过最前面的Dataloader的__next__函数可以看到DataLoader对数据的读取其实就是用了for循环来遍历数据,不用往上翻了,我直接复制了一遍,如下:

    class DataLoader(object): 
        ... 
        def __next__(self): 
            if self.num_workers == 0:   
                indices = next(self.sample_iter)  
                batch = self.collate_fn([self.dataset[i] for i in indices]) # this line 
                if self.pin_memory: 
                    batch = _utils.pin_memory.pin_memory_batch(batch) 
                return batch
    

    我们仔细看可以发现,前面还有一个self.collate_fn方法,这个是干嘛用的呢?在介绍前我们需要知道每个参数的意义:

    • indices: 表示每一个iteration,sampler返回的indices,即一个batch size大小的索引列表
    • self.dataset[i]: 前面已经介绍了,这里就是对第i个数据进行读取操作,一般来说self.dataset[i]=(img, label)

    看到这不难猜出collate_fn的作用就是将一个batch的数据进行合并操作。默认的collate_fn是将img和label分别合并成imgs和labels,所以如果你的__getitem__方法只是返回 img, label,那么你可以使用默认的collate_fn方法,但是如果你每次读取的数据有img, box, label等等,那么你就需要自定义collate_fn来将对应的数据合并成一个batch数据,这样方便后续的训练步骤。

  • 相关阅读:
    smtp发送邮件
    鼠标点击成烟花js代码
    使用Database Control访问数据库问题解决了
    ext grid 的每行最后一列添加 按钮
    jquery对下拉框的操作
    SQL Server 2005中DateTime类型转换为Varchar类型的所有格式
    winform安装项目、安装包的制作、部署
    js解释器rhino查看执行环境
    ecma2623执行环境练习
    javascript排序算法
  • 原文地址:https://www.cnblogs.com/E-Dreamer-Blogs/p/13770045.html
Copyright © 2011-2022 走看看