zoukankan      html  css  js  c++  java
  • pytorch (四) 数据加载

    自定义加载数据

    torch.utils.data.Dataset是一个抽象类,用户想要加载自定义的数据只需要继承这个类,并且覆写其中的两个方法即可:

    1. __len__:实现len(dataset)返回整个数据集的大小。
    2. __getitem__用来获取一些索引的数据,使dataset[i]返回数据集中第i个样本。
    3. 不覆写这两个方法会直接返回错误。
    from torch.utils.data import DataLoader,Dataset
    class MyData(Dataset): #继承Dataset
        def __init__(self, root_dir, transform=None): #初始化图片路径,一些变换操作。
            self.root_dir = root_dir   #文件目录
            self.transform = transform #变换
            self.images = os.listdir(self.root_dir)#目录里的所有文件
        
        def __len__(self):#返回整个数据集的大小
            return len(self.images)
        
        def __getitem__(self,index):#根据索引index返回dataset[index]
            image_index = self.images[index]#根据索引index获取该图片
            img_path = os.path.join(self.root_dir, image_index)#获取索引为index的图片的路径名
            img = io.imread(img_path)# 读取该图片
            label = img_path.split('\')[-1].split('.')[0]# 根据该图片的路径名获取该图片的label
            sample = {'image':img,'label':label}#根据图片和标签创建字典
            
            if self.transform:
                sample = self.transform(sample)#对样本进行变换
            return sample #返回该样本
    

    之后使用torch.utils.data.DataLoader加载数据

    data = MyData('path',transform=None)#初始化类,设置数据集所在路径以及变换
    dataloader = DataLoader(data,batch_size=128,shuffle=True)#使用DataLoader加载数据
    
    

    加载时不要涉及预处理,把该预处理的都提前做完。比如resize事先处理完,crop,flip和normalize在加载时候处理。

  • 相关阅读:
    第01组 Beta冲刺(5-5)
    第01组 Beta冲刺(4-5)
    第01组 Beta冲刺(3-5)
    第01组 Beta冲刺(2-5)
    第01组 Beta冲刺(1-5)
    软工实践个人总结
    第03组 每周小结 (3/3)
    第03组 每周小结 (2/3)
    第03组 每周小结 (1/3)
    第03组 Beta冲刺 总结
  • 原文地址:https://www.cnblogs.com/leimu/p/13366688.html
Copyright © 2011-2022 走看看