数据集对象
在 pytorch 中,数据加载可以通过自定义数据集对象实现;
数据集对象被抽象为 DataSet 类;
自定义数据集对象,需要继承该类,并且实现 __getitem__ 和 __len__ 两个方法
示例
class DogCat(data.Dataset): def __init__(self, path): images = os.listdir(path) self.imgs = [os.path.join(path, i) for i in images] def __getitem__(self, index): ### 只是简单读取图片,并未做任何处理 img_path = self.imgs[index] label = 1 if 'dog' in img_path.split('\')[-1] else 0 pil_img = Image.open(img_path) array = np.asarray(pil_img) data = t.from_numpy(array) return data, label def __len__(self): return len(self.imgs) dataSet = DogCat(r'F:dl_datasetcat_dog rain') img, label = dataSet[0] print(img.size(), img.float().mean()) print(label)
数据预处理
通常在加载数据时需要进行数据处理;
torchvision 是一个视觉工具包,提供了很多视觉图像处理的工具,其中 transforms 提供了对 PIL Image 和 Tensor 对象的常用操作
PIL Image 常见操作我后续会专门写一篇博客;
Tensor 常见操作
- Normalize:标准化,减去均值除以标准差
- ToPILImage:将 Tensor 转成 PILImage 对象
这些操作定义后都以对象的形式存在,真正使用时调用它的 __call__ 方法
# 正确写法 show = transforms.ToPILImage() image = show(img) # 错误写法 show = transforms.ToPILImage(img)
另外需注意:
1. Compose 将多个操作拼接起来,类似于 nn.Sequential
2. transforms.Lambda 方法支持用户自定义数据处理策略
示例
import os import numpy as np import torch as t from torch.utils import data from PIL import Image from torchvision import transforms import matplotlib.pylab as plt transform = transforms.Compose([ transforms.Resize(224), ### 缩放图片,长宽比不变,最短边为 224 像素 transforms.CenterCrop(224), ### 从中间裁剪出 224x224 transforms.ToTensor(), ### image to tensor,并归一化至 [0, 1] transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]) ### 标准化至 [-1, 1] ]) class DogCat(data.Dataset): def __init__(self, path, transforms=None): images = os.listdir(path) self.imgs = [os.path.join(path, i) for i in images] self.transforms = transforms def __getitem__(self, index): self.img = self.imgs[index] data = Image.open(self.img) label = 0 if 'dog' in self.img.split('\')[-1] else 1 if self.transforms: ### 自定义转换操作 mytrans = transforms.Lambda(lambda img: img.rotate(np.random.rand() * 360)) data = mytrans(data) data = self.transforms(data) return data, label def __len__(self): return len(self.imgs) dataSet = DogCat(r'F:dl_datasetcat_dog rain', transforms=transform) img, label = dataSet[0] print(img.size()) print(label) show = transforms.ToPILImage() image = show((img+1)/2) image.show()
批量加载
DataSet 负责数据集的抽象,其 __getitem__ 方法每次获取一个样本,这不利于网络的训练,我们需要 batch、shuffle、甚至并行;
在 pytorch 中使用 DataLoader 实现上述需求
class DataLoader(object): r""" Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset. """ __initialized = False def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None):
参数解释:
- dataset、batch_size、shuffle 不解释
- sampler:样本抽样,如随机采样,RandomSampler,当 shuffle 为 True 时,系统自动调用该采样方式,实现打乱数据;默认采样器为 SequentialSampler,逐个采样;WeightRandomSampler,在类别不均衡问题中,这种方式可以实现重采样
- batch_sampler:
- num_workers:使用多进程加载数据, 0 代表不使用多进程
- collate_fn:如何将多个样本拼接成一个 batch,一般使用默认的方式即可
- pin_memory:是否将数据保存在 pin memory 区,pin memory 区的数据转到 GPU 会更快一些
- drop_last:如果 dataset 中最后的数据不足一个 batch,弃掉
DataLoader 生成的数据类似于一个迭代器,有两种方式读取该数据
示例
from torch.utils.data import DataLoader dataloader = DataLoader(dataSet, batch_size=3, shuffle=True, num_workers=0, drop_last=False) ### 类似于迭代器 ### 批量获取数据有两种方式 # 方式1 dataiter = iter(dataloader) imgs, label = next(dataiter) print(imgs.size()) show = transforms.ToPILImage() image = show((imgs[0]+1)/2) image.show() # 方式2 for batch_data, batch_label in dataloader: print(batch_label)
WeightRandomSampler
加载异常
在数据处理中,有时会遇到某个样本无法读取的情况,如图片已损坏,此时 __getitem__ 函数将出现异常,处理方式有几种:
1. 剔除错误样本
2. __getitem__ 返回 None,然后自定义 collate_fn,将空过滤掉,但是这种情况获取的 batch 会少于 batch_size
3. 随机找一张代替
4. 提前进行数据清洗
DataSet 和 DataLoader 使用建议
1. 高负载的操作放在 __getitem__ 中,如图片读取
// 多进程会并行的调用 __getitem__ 方法,高负载的操作并行执行提高效率
2. dataSet 中尽量只包含只读对象,避免修改任何可变对象
// 线程安全问题
参考资料: