  • PyTorch data-loading interface: building a dataset with data.Dataset

    【Rant】

    Ugh, code, you fickle thing.

    I wrote my own CIFAR-10 data interface, and it loaded exactly the same data as the official one,

    so I was pretty pleased with myself, thinking I now knew how to write data interfaces.

    A few days later it hit me: why is my code kind of slow? This dataset isn't even that big.

    Tried the official interface. Wow, it's fast...

    Aaaaaargh.

    Still, it's a good thing. At least I learned something, right?

    【Lesson】

    Reading the CIFAR-10 interface, I realized I did far too little in the dataset's initialization. All the data should be read in during __init__; that is what makes __getitem__ fast.
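    The lesson in one minimal sketch (a toy in-memory stand-in for the CIFAR-10 batch files, not a real loader):

    ```python
    # Minimal sketch of the "load everything in __init__" pattern.
    # Each "batch" stands in for one pickled data_batch file: all reading
    # happens once, up front, so __getitem__ is just an index lookup.

    class PreloadedDataset:
        def __init__(self, batches):
            # batches: list of (data_list, label_list) pairs, one per "file"
            self.data = []
            self.labels = []
            for data_part, label_part in batches:
                self.data.extend(data_part)
                self.labels.extend(label_part)

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index):
            # Cheap: no file I/O per sample
            return self.data[index], self.labels[index]


    batches = [([10, 11], [0, 1]), ([12, 13], [1, 0])]
    ds = PreloadedDataset(batches)
    print(len(ds), ds[2])  # -> 4 (12, 1)
    ```

    The real torchvision class follows the same shape: __init__ concatenates the five batch files into one array, and __getitem__ only indexes it.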

     Their __init__:

     if self.train:
                self.train_data = []
                self.train_labels = []
                for fentry in self.train_list:
                    f = fentry[0]
                    file = os.path.join(self.root, self.base_folder, f)
                    fo = open(file, 'rb')
                    if sys.version_info[0] == 2:
                        entry = pickle.load(fo)
                    else:
                        entry = pickle.load(fo, encoding='latin1')
                    self.train_data.append(entry['data'])
                    if 'labels' in entry:
                        self.train_labels += entry['labels']
                    else:
                        self.train_labels += entry['fine_labels']
                    fo.close()
    
                self.train_data = np.concatenate(self.train_data)
                self.train_data = self.train_data.reshape((50000, 3, 32, 32))
                self.train_data = self.train_data.transpose((0, 2, 3, 1))  # convert to HWC
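    The reshape + transpose at the end is what turns the flat pickle rows into HWC images. A tiny standalone check (fake data, not real CIFAR-10 files):

    ```python
    import numpy as np

    # Each CIFAR-10 row is 3072 values: 1024 R, then 1024 G, then 1024 B.
    # reshape to (N, 3, 32, 32) gives channel-first (CHW) images;
    # transpose((0, 2, 3, 1)) reorders them to HWC, which PIL expects.
    flat = np.arange(2 * 3072).reshape(2, 3072)  # 2 fake "images"
    chw = flat.reshape((2, 3, 32, 32))
    hwc = chw.transpose((0, 2, 3, 1))
    print(chw.shape, hwc.shape)  # (2, 3, 32, 32) (2, 32, 32, 3)
    # the R value of pixel (0, 0) in image 0 is the first byte of its row
    print(hwc[0, 0, 0, 0] == flat[0, 0])  # True
    ```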

    Their __getitem__:

        def __getitem__(self, index):
            """
            Args:
                index (int): Index
    
            Returns:
                tuple: (image, target) where target is index of the target class.
            """
            if self.train:
                img, target = self.train_data[index], self.train_labels[index]
            else:
                img, target = self.test_data[index], self.test_labels[index]
    
            # doing this so that it is consistent with all other datasets
            # to return a PIL Image
            img = Image.fromarray(img)
    
            if self.transform is not None:
                img = self.transform(img)
    
            if self.target_transform is not None:
                target = self.target_transform(target)
    
            return img, target

    Mine (everything crammed into __getitem__):

     def __init__(self, root, transforms=transform(), train=True, test=False):
            self.root = root
            self.transform = transforms
            self.train = train
            self.test = test
            if self.test:
                self.train = False
    
    def __getitem__(self, item):
        # items 0..49999 map onto data_batch_1..5, 10000 samples per file
        x = item // 10000 + 1
        y = item % 10000
        if not self.train and not self.test:
            # validation set: the second half of data_batch_5
            x = 5
            y = 5000 + item

        imgpath = os.path.join(self.root, "data_batch_" + str(x))
        with open(imgpath, 'rb') as fo:
            batch = pickle.load(fo, encoding='bytes')
            # keys come back as bytes; decode them to str
            # (also avoids shadowing the built-in dict)
            batch = {k.decode('utf8'): v for k, v in batch.items()}
            data = batch['data'][y]  # 3*32*32 == 3072
            data = np.reshape(data, (3, 32, 32))
            data = data.transpose(1, 2, 0)  # CHW -> HWC
            data = self.transform(data)
            label = batch['labels'][y]
            # label = torch.from_numpy(label)

            return data, label
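    Why this is slow: every __getitem__ call re-opens and re-unpickles a whole batch file just to pull out one sample. A toy illustration of the two access patterns (temporary pickle file with made-up sizes, not a real CIFAR-10 batch):

    ```python
    import os
    import pickle
    import tempfile
    import time

    # Fake "data_batch" file: 10000 samples, like one CIFAR-10 batch
    tmp = tempfile.mkdtemp()
    path = os.path.join(tmp, "data_batch_1")
    with open(path, "wb") as fo:
        pickle.dump({"data": list(range(10000)),
                     "labels": [i % 10 for i in range(10000)]}, fo)

    def get_slow(i):
        # mimics my version: unpickle the whole file on every access
        with open(path, "rb") as fo:
            b = pickle.load(fo)
        return b["data"][i], b["labels"][i]

    with open(path, "rb") as fo:
        batch = pickle.load(fo)  # mimics "__init__": read the file once

    def get_fast(i):
        # mimics the official version: just index preloaded data
        return batch["data"][i], batch["labels"][i]

    t0 = time.perf_counter()
    for i in range(200):
        get_slow(i)
    slow = time.perf_counter() - t0

    t0 = time.perf_counter()
    for i in range(200):
        get_fast(i)
    fast = time.perf_counter() - t0

    print(get_slow(7) == get_fast(7), fast < slow)
    ```

    Same results, but the preloaded version skips 200 rounds of file I/O and unpickling, which is exactly the gap I was feeling with the real dataset.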

    Appendix: my code and theirs in full

    Theirs:

    base_folder = 'cifar-10-batches-py'
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]

    def __init__(self, root, train=True,
                 transform=None, target_transform=None,
                 download=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        # now load the picked numpy arrays
        if self.train:
            self.train_data = []
            self.train_labels = []
            for fentry in self.train_list:
                f = fentry[0]
                file = os.path.join(self.root, self.base_folder, f)
                fo = open(file, 'rb')
                if sys.version_info[0] == 2:
                    entry = pickle.load(fo)
                else:
                    entry = pickle.load(fo, encoding='latin1')
                self.train_data.append(entry['data'])
                if 'labels' in entry:
                    self.train_labels += entry['labels']
                else:
                    self.train_labels += entry['fine_labels']
                fo.close()

            self.train_data = np.concatenate(self.train_data)
            self.train_data = self.train_data.reshape((50000, 3, 32, 32))
            self.train_data = self.train_data.transpose((0, 2, 3, 1))  # convert to HWC
        else:
            f = self.test_list[0][0]
            file = os.path.join(self.root, self.base_folder, f)
            fo = open(file, 'rb')
            if sys.version_info[0] == 2:
                entry = pickle.load(fo)
            else:
                entry = pickle.load(fo, encoding='latin1')
            self.test_data = entry['data']
            if 'labels' in entry:
                self.test_labels = entry['labels']
            else:
                self.test_labels = entry['fine_labels']
            fo.close()
            self.test_data = self.test_data.reshape((10000, 3, 32, 32))
            self.test_data = self.test_data.transpose((0, 2, 3, 1))  # convert to HWC

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target
  • Original post: https://www.cnblogs.com/yexiaoqi/p/10510960.html