zoukankan      html  css  js  c++  java
  • PyTorch自定义数据集

    数据传递机制

    我们首先回顾识别手写数字的程序:

    ...
    Dataset = torchvision.datasets.MNIST(root='./mnist/', train=True, transform=transform, download=True,)
    dataloader = torch.utils.data.DataLoader(dataset=Dataset, batch_size=64, shuffle=True)
    ...
    for epoch in range(EPOCH):
        for i, (image, label) in enumerate(dataloader):
            ...

    从上面的程序,我们可以知道,在PyTorch中,数据传递机制是这样的:

    1. 创建Dataset
    2. Dataset传递给DataLoader
    3. DataLoader迭代产生训练数据提供给模型

    总结这个数据传递机制就是,Dataset负责建立索引到样本的映射,DataLoader负责以特定的方式从数据集中迭代的产生一个个batch的样本集合。在enumerate过程中实际上是dataloader按照其参数sampler规定的策略调用了其dataset的getitem方法(下文中将介绍该方法)。关于Dataloder和Dataset的关系,具体可参考博客PyTorch中Dataset, DataLoader, Sampler的关系

    在上面的识别手写数字的例子中,数据集是直接下载的,但如果我们自己收集了一些数据,存在电脑文件夹里,我们该如何把这些数据变为可以在PyTorch框架下进行神经网络训练的数据集呢,即如何自定义数据集呢?

    自定义数据集

    torch.utils.data.Dataset 是一个表示数据集的抽象类。任何自定义的数据集都需要继承这个类并覆写相关方法。所谓数据集,其实就是一个负责处理索引(index)到样本(sample)映射的一个类(class)。Pytorch提供两种数据集: Map式数据集 Iterable式数据集。这里我们只介绍前者。

    一个Map式的数据集必须要重写getitem(self, index)、 len(self) 两个内建方法,用来表示从索引到样本的映射(Map)。这样一个数据集dataset,举个例子,当使用dataset[idx]命令时,可以在你的硬盘中读取数据集中第idx张图片以及其标签(如果有的话); len(dataset)则会返回这个数据集的容量。

    自定义数据集类的范式大致是这样的:

    class CustomDataset(torch.utils.data.Dataset):#需要继承torch.utils.data.Dataset
        def __init__(self):
            # TODO
            # 1. Initialize file path or list of file names.
            pass
        def __getitem__(self, index):
            # TODO
            # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
            # 2. Preprocess the data (e.g. torchvision.Transform).
            # 3. Return a data pair (e.g. image and label).
            #这里需要注意的是,第一步:read one data,是一个data point
            pass
        def __len__(self):
            # You should change 0 to the total size of your dataset.
            return 0
    

    根据这个范式,我们举一个例子。

    实例

    从kaggle官网下载dogsVScats的数据集(百度网盘下载链接见文末),该数据集包含test1文件夹和train文件夹,train文件夹中包含12500张猫的图片和12500张狗的图片,图片的文件名中带序号:

    cat.0.jpg
    cat.1.jpg
    cat.2.jpg
    ...
    cat.12499.jpg
    dog.0.jpg
    dog.1.jpg
    dog.2.jpg
    ...
    dog.12499.jpg

    我们把其中前10000张猫的图片和10000张狗的图片作为训练集,把后面的2500张猫的图片和2500张狗的图片作为验证集。猫的label记为0,狗的label记为1。因为图片大小不一,所以,我们需要对图像进行transform。

    import matplotlib.pyplot as plt
    import numpy as np
    import torch
    from torch.utils.data import Dataset, DataLoader
    from torchvision import transforms
    from PIL import Image
    import os
    image_transform = transforms.Compose([
        transforms.Resize(256),               # 把图片resize为256*256
        transforms.RandomCrop(224),           # 随机裁剪224*224
        transforms.RandomHorizontalFlip(),    # 水平翻转
        transforms.ToTensor(),                # 将图像转为Tensor
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])   # 标准化
    ])
    
    class DogVsCatDataset(Dataset):   # 创建一个叫做DogVsCatDataset的Dataset,继承自父类torch.utils.data.Dataset
        def __init__(self, root_dir, train=True, transform=None):
            """
            Args:
                root_dir (string): Directory with all the images.
                transform (callable, optional): Optional transform to be applied on a sample.
            """
            self.root_dir = root_dir
            self.img_path = os.listdir(self.root_dir)
            if train:
                self.img_path = list(filter(lambda x: int(x.split('.')[1]) < 10000, self.img_path))    # 划分训练集和验证集
            else:
                self.img_path = list(filter(lambda x: int(x.split('.')[1]) >= 10000, self.img_path))
            self.transform = transform
    
        def __len__(self):
            return len(self.img_path)
    
        def __getitem__(self, idx):
            image = Image.open(os.path.join(self.root_dir, self.img_path[idx]))
            label = 0 if self.img_path[idx].split('.')[0] == 'cat' else 1        # label, 猫为0,狗为1
            if self.transform:
                image = self.transform(image)
            label = torch.from_numpy(np.array([label]))
            return image, label

    我们来测试一下:

    if __name__ == '__main__':
        catanddog_dataset = DogVsCatDataset(root_dir='/Users/wangpeng/Desktop/train',
                                            train=False,
                                            transform=image_transform)
        train_loader = DataLoader(catanddog_dataset, batch_size=8, shuffle=True, num_workers=4)   # num_workers=4表示用4个线程读取数据
        image, label = iter(train_loader).next()   # iter()函数把train_loader变为迭代器,然后调用迭代器的next()方法
        sample = image[0].squeeze()
        sample = sample.permute((1, 2, 0)).numpy()
        sample *= [0.229, 0.224, 0.225]
        sample += [0.485, 0.456, 0.406]
        sample = np.clip(sample, 0, 1)
        plt.imshow(sample)
        plt.show()
        print('Label is: {}'.format(label[0].numpy()))

    运行结果:

    Label is: [0] 

    dogsVScats数据下载链接:链接:https://pan.baidu.com/s/17768gqeaX9NrdURV_tR_ow  提取密码:478x

    参考文献

    [1] Pytorch之Dataset与DataLoader,打造你自己的数据集

    [2] 基于PyTorch的卷积神经网络图像分类——猫狗大战(一):使用Pytorch定义DataLoader

  • 相关阅读:
    电源跳闸或突然断电后Kafka启动失败问题
    Failure to find org.glassfish:javax.el:pom:3.0.1b06SNAPSHOT
    Idea中的maven工程运行Scala报Command execution failed
    Scala(一)入门
    HBase2.0.5
    GridView之CommandField的妙用——点击提示删除
    SharePoint 2010在新窗口打开文档库中的文件
    SharePoint 2010 使用后台代码向SP.UI.ModalDialog.showModalDialog传值
    SharePoint 2010 使用代码创建视图查询条件
    Sharepoint 2010 禁止用户在文档库的第一级(根)目录上传文件
  • 原文地址:https://www.cnblogs.com/picassooo/p/12846617.html
Copyright © 2011-2022 走看看