zoukankan      html  css  js  c++  java
  • [转载]pytorch自定义数据集

    为什么要定义Datasets:

    PyTorch提供了一个工具函数torch.utils.data.DataLoader。通过这个类,我们在准备mini-batch的时候可以多线程并行处理,这样可以加快准备数据的速度。Datasets就是构建这个类的实例的参数之一。

    如何自定义Datasets

    下面是一个自定义Datasets的框架:

    class CustomDataset(data.Dataset):#需要继承data.Dataset
        def __init__(self):
            # TODO
            # 1. Initialize file path or list of file names.
            pass
        def __getitem__(self, index):
            # TODO
            # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
            # 2. Preprocess the data (e.g. torchvision.Transform).
            # 3. Return a data pair (e.g. image and label).
            #这里需要注意的是,第一步:read one data,是一个data
            pass
        def __len__(self):
            # You should change 0 to the total size of your dataset.
            return 0

    下面看一下官方MNIST的例子(代码被缩减,只留下了重要的部分):

    class MNIST(data.Dataset):
        def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
            self.root = root
            self.transform = transform
            self.target_transform = target_transform
            self.train = train  # training set or test set
     
            if download:
                self.download()
     
            if not self._check_exists():
                raise RuntimeError('Dataset not found.' +
                                   ' You can use download=True to download it')
     
            if self.train:
                self.train_data, self.train_labels = torch.load(
                    os.path.join(root, self.processed_folder, self.training_file))
            else:
                self.test_data, self.test_labels = torch.load(os.path.join(root, self.processed_folder, self.test_file))
     
        def __getitem__(self, index):
            if self.train:
                img, target = self.train_data[index], self.train_labels[index]
            else:
                img, target = self.test_data[index], self.test_labels[index]
     
            # doing this so that it is consistent with all other datasets
            # to return a PIL Image
            img = Image.fromarray(img.numpy(), mode='L')
     
            if self.transform is not None:
                img = self.transform(img)
     
            if self.target_transform is not None:
                target = self.target_transform(target)
     
            return img, target
     
        def __len__(self):
            if self.train:
                return 60000
            else:
                return 10000
  • 相关阅读:
    20140830 函数 递归
    函数 20140829
    结构体20140827
    20140826 集合
    20140822数组,应用举例
    140821 字符串,数字,日期及应用举例
    20140819 例子
    HTML基础
    登陆远程服务器
    索引 视图 游标
  • 原文地址:https://www.cnblogs.com/wzyuan/p/9461868.html
Copyright © 2011-2022 走看看