zoukankan      html  css  js  c++  java
  • pytorch (四) 数据加载

    自定义加载数据

    torch.utils.data.Dataset是一个抽象类,用户想要加载自定义的数据只需要继承这个类,并且覆写其中的两个方法即可:

    1. __len__:实现len(dataset)返回整个数据集的大小。
    2. __getitem__用来获取一些索引的数据,使dataset[i]返回数据集中第i个样本。
    3. 不覆写这两个方法会直接返回错误。
    from torch.utils.data import DataLoader,Dataset
    class MyData(Dataset): #继承Dataset
        def __init__(self, root_dir, transform=None): #初始化图片路径,一些变换操作。
            self.root_dir = root_dir   #文件目录
            self.transform = transform #变换
            self.images = os.listdir(self.root_dir)#目录里的所有文件
        
        def __len__(self):#返回整个数据集的大小
            return len(self.images)
        
        def __getitem__(self,index):#根据索引index返回dataset[index]
            image_index = self.images[index]#根据索引index获取该图片
            img_path = os.path.join(self.root_dir, image_index)#获取索引为index的图片的路径名
            img = io.imread(img_path)# 读取该图片
            label = img_path.split('\')[-1].split('.')[0]# 根据该图片的路径名获取该图片的label
            sample = {'image':img,'label':label}#根据图片和标签创建字典
            
            if self.transform:
                sample = self.transform(sample)#对样本进行变换
            return sample #返回该样本
    

    之后使用torch.utils.data.DataLoader加载数据

    data = MyData('path',transform=None)#初始化类,设置数据集所在路径以及变换
    dataloader = DataLoader(data,batch_size=128,shuffle=True)#使用DataLoader加载数据
    
    

    加载时不要涉及预处理,把该预处理的都提前做完。比如resize事先处理完,crop,flip和normalize在加载时候处理。

  • 相关阅读:
    poj 1328 Radar Installation (贪心)
    hdu 2037 今年暑假不AC (贪心)
    poj 2965 The Pilots Brothers' refrigerator (dfs)
    poj 1753 Flip Game (dfs)
    hdu 2838 Cow Sorting (树状数组)
    hdu 1058 Humble Numbers (DP)
    hdu 1069 Monkey and Banana (DP)
    hdu 1087 Super Jumping! Jumping! Jumping! (DP)
    必须知道的.NET FrameWork
    使用记事本+CSC编译程序
  • 原文地址:https://www.cnblogs.com/leimu/p/13366688.html
Copyright © 2011-2022 走看看