zoukankan      html  css  js  c++  java
  • pytorch (四) 数据加载

    自定义加载数据

    torch.utils.data.Dataset是一个抽象类,用户想要加载自定义的数据只需要继承这个类,并且覆写其中的两个方法即可:

    1. __len__:实现len(dataset)返回整个数据集的大小。
    2. __getitem__用来获取一些索引的数据,使dataset[i]返回数据集中第i个样本。
    3. 不覆写这两个方法会直接返回错误。
    from torch.utils.data import DataLoader,Dataset
    class MyData(Dataset): #继承Dataset
        def __init__(self, root_dir, transform=None): #初始化图片路径,一些变换操作。
            self.root_dir = root_dir   #文件目录
            self.transform = transform #变换
            self.images = os.listdir(self.root_dir)#目录里的所有文件
        
        def __len__(self):#返回整个数据集的大小
            return len(self.images)
        
        def __getitem__(self,index):#根据索引index返回dataset[index]
            image_index = self.images[index]#根据索引index获取该图片
            img_path = os.path.join(self.root_dir, image_index)#获取索引为index的图片的路径名
            img = io.imread(img_path)# 读取该图片
            label = img_path.split('\')[-1].split('.')[0]# 根据该图片的路径名获取该图片的label
            sample = {'image':img,'label':label}#根据图片和标签创建字典
            
            if self.transform:
                sample = self.transform(sample)#对样本进行变换
            return sample #返回该样本
    

    之后使用torch.utils.data.DataLoader加载数据

    data = MyData('path',transform=None)#初始化类,设置数据集所在路径以及变换
    dataloader = DataLoader(data,batch_size=128,shuffle=True)#使用DataLoader加载数据
    
    

    加载时不要涉及预处理,把该预处理的都提前做完。比如resize事先处理完,crop,flip和normalize在加载时候处理。

  • 相关阅读:
    阅读INI档
    jQuery遍历table中间tr td并获得td价值
    PB控制性能TreeView
    [POJ 3311]Hie with the Pie——谈论TSP难题DP解决方法
    数据结构 线性表
    ORACLE11G RAC 施加以分离不同的实例.TAF
    一起学习android使用一个回调函数onCreateDialog实现负载对话(23)
    [cocos2d-x 3.0] 触摸显示器
    lua学习笔记10:lua简单的命令行
    Matlab图像处理系列4———傅立叶变换和反变换的图像
  • 原文地址:https://www.cnblogs.com/leimu/p/13366688.html
Copyright © 2011-2022 走看看