zoukankan      html  css  js  c++  java
  • PyTorch【8】-数据加载

    数据集对象

    在 pytorch 中,数据加载可以通过自定义数据集对象实现;

    数据集对象被抽象为 DataSet 类;

    自定义数据集对象,需要继承该类,并且实现 __getitem__ 和 __len__ 两个方法

    示例

    class DogCat(data.Dataset):
        def __init__(self, path):
            images = os.listdir(path)
            self.imgs = [os.path.join(path, i) for i in images]
    
        def __getitem__(self, index):
            ### 只是简单读取图片,并未做任何处理
            img_path = self.imgs[index]
            label = 1 if 'dog' in img_path.split('\')[-1] else 0
            pil_img = Image.open(img_path)
    
            array = np.asarray(pil_img)
            data = t.from_numpy(array)
            return data, label
    
        def __len__(self):
            return len(self.imgs)
    
    dataSet = DogCat(r'F:dl_datasetcat_dog	rain')
    img, label = dataSet[0]
    print(img.size(), img.float().mean())
    print(label)

    数据预处理

    通常在加载数据时需要进行数据处理;

    torchvision 是一个视觉工具包,提供了很多视觉图像处理的工具,其中 transforms 提供了对 PIL Image 和 Tensor 对象的常用操作

    PIL Image 常见操作我后续会专门写一篇博客;

    Tensor 常见操作

    • Normalize:标准化,减去均值除以标准差
    • ToPILImage:将 Tensor 转成 PILImage 对象

    这些操作定义后都以对象的形式存在,真正使用时调用它的 __call__ 方法

    # 正确写法
    show = transforms.ToPILImage()
    image = show(img)
    # 错误写法
    show = transforms.ToPILImage(img)

    另外需注意:

    1. Compose 将多个操作拼接起来,类似于 nn.Sequential

    2. transforms.Lambda 方法支持用户自定义数据处理策略

    示例

    import os
    import numpy as np
    import torch as t
    from torch.utils import data
    from PIL import Image
    from torchvision import transforms
    import matplotlib.pylab as plt
    
    
    transform = transforms.Compose([
        transforms.Resize(224),     ### 缩放图片,长宽比不变,最短边为 224 像素
        transforms.CenterCrop(224), ### 从中间裁剪出 224x224
        transforms.ToTensor(),      ### image to tensor,并归一化至 [0, 1]
        transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])   ### 标准化至 [-1, 1]
    
    ])
    
    
    class DogCat(data.Dataset):
        def __init__(self, path, transforms=None):
            images = os.listdir(path)
            self.imgs = [os.path.join(path, i) for i in images]
            self.transforms = transforms
    
        def __getitem__(self, index):
            self.img = self.imgs[index]
            data = Image.open(self.img)
            label = 0 if 'dog' in self.img.split('\')[-1] else 1
    
            if self.transforms:
                ### 自定义转换操作
                mytrans = transforms.Lambda(lambda img: img.rotate(np.random.rand() * 360))
                data = mytrans(data)
                data = self.transforms(data)
    
            return data, label
    
        def __len__(self):
            return len(self.imgs)
    
    dataSet = DogCat(r'F:dl_datasetcat_dog	rain', transforms=transform)
    img, label = dataSet[0]
    print(img.size())
    print(label)
    
    show = transforms.ToPILImage()
    image = show((img+1)/2)
    image.show()

    批量加载

    DataSet 负责数据集的抽象,其 __getitem__ 方法每次获取一个样本,这不利于网络的训练,我们需要 batch、shuffle、甚至并行;

    在 pytorch 中使用 DataLoader 实现上述需求

    class DataLoader(object):
        r"""
        Data loader. Combines a dataset and a sampler, and provides an iterable over
        the given dataset.
        """
    
        __initialized = False
    
        def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                     batch_sampler=None, num_workers=0, collate_fn=None,
                     pin_memory=False, drop_last=False, timeout=0,
                     worker_init_fn=None, multiprocessing_context=None):

    参数解释:

    • dataset、batch_size、shuffle 不解释
    • sampler:样本抽样,如随机采样,RandomSampler,当 shuffle 为 True 时,系统自动调用该采样方式,实现打乱数据;默认采样器为 SequentialSampler,逐个采样;WeightRandomSampler,在类别不均衡问题中,这种方式可以实现重采样
    • batch_sampler:
    • num_workers:使用多进程加载数据, 0 代表不使用多进程
    • collate_fn:如何将多个样本拼接成一个 batch,一般使用默认的方式即可
    • pin_memory:是否将数据保存在 pin memory 区,pin memory 区的数据转到 GPU 会更快一些
    • drop_last:如果 dataset 中最后的数据不足一个 batch,弃掉

    DataLoader 生成的数据类似于一个迭代器,有两种方式读取该数据

    示例

    from torch.utils.data import DataLoader
    
    dataloader = DataLoader(dataSet, batch_size=3, shuffle=True, num_workers=0, drop_last=False)    ### 类似于迭代器
    
    ### 批量获取数据有两种方式
    # 方式1
    dataiter = iter(dataloader)
    imgs, label = next(dataiter)
    print(imgs.size())
    
    show = transforms.ToPILImage()
    image = show((imgs[0]+1)/2)
    image.show()
    
    # 方式2
    for batch_data, batch_label in dataloader:
        print(batch_label)

    WeightRandomSampler

    加载异常

    在数据处理中,有时会遇到某个样本无法读取的情况,如图片已损坏,此时 __getitem__ 函数将出现异常,处理方式有几种:

    1. 剔除错误样本

    2. __getitem__ 返回 None,然后自定义 collate_fn,将空过滤掉,但是这种情况获取的 batch 会少于 batch_size

    3. 随机找一张代替

    4. 提前进行数据清洗

    DataSet 和 DataLoader 使用建议

    1. 高负载的操作放在 __getitem__ 中,如图片读取

      // 多进程会并行的调用 __getitem__ 方法,高负载的操作并行执行提高效率

    2. dataSet 中尽量只包含只读对象,避免修改任何可变对象

      // 线程安全问题

    参考资料:

  • 相关阅读:
    IDEA创建test测试类
    SpringBoot Unable to find a @SpringBootConfiguration, you need to use @ContextConfiguration
    Mysql在线加索引锁表归纳
    工作感悟--对上一份工作总结
    ESP8266获取网络NTP时间(转)
    Python中的CGI编程 config配置(windows、Apache) 以及后期的编写(转)
    CGI与FastCGI(转)
    JSON-RPC轻量级远程调用协议介绍及使用
    java插件化编程(动态加载)
    PF4J入门指南
  • 原文地址:https://www.cnblogs.com/yanshw/p/12222013.html
Copyright © 2011-2022 走看看