自定义加载数据
torch.utils.data.Dataset是一个抽象类,用户想要加载自定义的数据只需要继承这个类,并且覆写其中的两个方法即可:
- __len__:实现len(dataset)返回整个数据集的大小。
- __getitem__用来获取一些索引的数据,使dataset[i]返回数据集中第i个样本。
- 不覆写这两个方法会直接返回错误。
from torch.utils.data import DataLoader,Dataset
class MyData(Dataset): #继承Dataset
def __init__(self, root_dir, transform=None): #初始化图片路径,一些变换操作。
self.root_dir = root_dir #文件目录
self.transform = transform #变换
self.images = os.listdir(self.root_dir)#目录里的所有文件
def __len__(self):#返回整个数据集的大小
return len(self.images)
def __getitem__(self,index):#根据索引index返回dataset[index]
image_index = self.images[index]#根据索引index获取该图片
img_path = os.path.join(self.root_dir, image_index)#获取索引为index的图片的路径名
img = io.imread(img_path)# 读取该图片
label = img_path.split('\')[-1].split('.')[0]# 根据该图片的路径名获取该图片的label
sample = {'image':img,'label':label}#根据图片和标签创建字典
if self.transform:
sample = self.transform(sample)#对样本进行变换
return sample #返回该样本
之后使用torch.utils.data.DataLoader加载数据
data = MyData('path',transform=None)#初始化类,设置数据集所在路径以及变换
dataloader = DataLoader(data,batch_size=128,shuffle=True)#使用DataLoader加载数据
加载时不要涉及预处理,把该预处理的都提前做完。比如resize事先处理完,crop,flip和normalize在加载时候处理。