zoukankan      html  css  js  c++  java
  • PyTorch学习记录003-Dataset和DataLoader

    1.utils.data包括Dataset和DataLoader

      torch.utils.data.Dataset为抽象类,表示Dataset的抽象类,所有其他数据集都应该进行子类化,所有子类应该override,__len__和__getitem__,前者提供了数据集的大小,后者支持整数索引,范围从0到len(self)。
      自定义数据集需要继承这个类,并实现两个函数,一个是__len__,另一个是__getitem__前者提供数据的大小(size),后者通过给定索引获取数据和标签__getitem__一次只能获取一个数据,所以需要通过torch.utils.data.DataLoader来定义一个新的迭代器,实现batch读取。
      首先定义获取数据集的类,该类继承基类Dataset,自定义一个数据集及对应标签。
    
    class TestDataset(data.Dataset): # 继承Dataset
        def __init__(self):
            # 一些由2维向量表示的数据集
            self.Data = np.asarray([[1,2],[3,4],[2,1],[3,4],[4,5]]) 
            # 这些是数据集对应的标签
            self.Label = np.asarray([0,1,0,1,2])
            
        def __getitem__(self, index):
            # 把numpy转换为tensor
            txt = torch.from_numpy(self.Data[index])
            label = torch.tensor(self.Label[index])
            return txt, label
        
        def __len__(self):
            return len(self.Data)
    
    
    Test = TestDataset()
    print(Test[2]) # 相当于调用__getitem__(2)
    print(Test.__len__())
    

    输出:

    (tensor([2, 1], dtype=torch.int32), tensor(0, dtype=torch.int32))
    5
    
      以上数据以tuple返回,每次只返回一个样本。实际上,Dateset只负责数据的抽取,调用一次__getitem__只返回一个样本。如果希望批量处理(batch),还要同时进行shuffle和并行加速等操作,可选择DataLoader。
    

    DataLoader的格式为:

    data.DataLoader(
    	dataset,                # 加载的数据集
    	batch_size=1,			# 批大小
    	shuffle=False,  		# 是否将数据打乱
    	sampler=None,			# 样本抽样
    	batch_sampler=None,
    	num_workers=0,			# 使用多进程加载的进程数,0代表不适用多进程
    	collate_fn=<function *>	# 如何将多个样本数据拼成一个batch
    	pin_memory=False,		# 是否将数据保存在pin memory中,pin memory中的数据转到GPU会快一些
    	drop_last=False,		# dataset中的数据个数可能不是batch_size的整数倍,drop_last为true会将多出来不足一个batch的数据丢弃
    	timeout=0,
    	worker_init_fn=None,
    )
    
    

    创建一个DataLoader:

    Test = TestDataset()
    test_loader = data.DataLoader(Test, batch_size = 2, 
    				    	shuffle = False, 
    				    	num_workers=2, 
    				    	drop_last = True)
    for i, traindata in enumerate(test_loader):
        print('i:{}'.format(i))
        Data, Label = traindata
        print('data:',Data)
        print('Label:', Label)
    

    输出:

    i:0
    data: tensor([[1, 2],
            [3, 4]], dtype=torch.int32)
    Label: tensor([0, 1], dtype=torch.int32)
    i:1
    data: tensor([[2, 1],
            [3, 4]], dtype=torch.int32)
    Label: tensor([0, 1], dtype=torch.int32)
    
      从这个结果可以看出,这是批量读取。我们可以像使用迭代器一样使用它,比如对它进行循环操作。不过由于它不是迭代器,我们可以通过iter命令将其转换为迭代器。
    
    dataiter = iter(test_loader)
    imgs,labels = next(dataiter)
    
      一般用data.Dataset处理同一个目录下的数据。如果数据在不同目录下,因为不同的目录代表不同类别(这种情况比较普遍),使用data.Dataset来处理就很不方便。不过,使用PyTorch另一种可视化数据处理工具(即torchvision)就非常方便,不但可以自动获取标签,还提供很多数据预处理、数据增强等转换函数。
  • 相关阅读:
    css3 flex 布局
    用CSS3 & jQuery创建apple TV海报视差效果
    JavaScript知识点的总结
    javascript 常用DOM操作整理
    html打造动画【系列4】哆啦A梦
    如何掌握jQuery插件开发(高能)
    前端基础进阶(一):内存空间详细图解
    JavaScript中数组对象详解
    [zhuan]JNIEnv解析
    在 C Level 用 dlopen 使用 第三方的 Shared Library (.so)
  • 原文地址:https://www.cnblogs.com/wzdszh/p/14109442.html
Copyright © 2011-2022 走看看