  • About torchvision.datasets.CIFAR10

    The DARTS code written for PyTorch 0.4 contains the following lines:

    trn_data = datasets.CIFAR10(root=data_path, train=True, download=False, transform=train_transform)
    shape = trn_data.train_data.shape

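    In version 1.2 and later, the source code shows that the CIFAR10 class no longer has a train_data attribute. It has been replaced by data, so the second line has to be changed to: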

    shape = trn_data.data.shape
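    If the same script has to run on both older and newer releases, one option is to look up whichever attribute exists. The sketch below uses a hypothetical helper, `dataset_array` (not part of torchvision). It tries the newer `data` attribute first and falls back to the older `train_data`. A second hypothetical helper, `input_size_and_channels`, shows how a script might read the image size and channel count from the resulting shape. It assumes NHWC arrays (or NHW for single-channel datasets) and square images. These are assumptions for illustration and are not taken from the DARTS code.

```python
def dataset_array(dataset):
    """Return the image array of a torchvision-style dataset.

    Hypothetical helper: tries the newer ``data`` attribute first and
    falls back to the older ``train_data`` attribute.
    """
    for name in ("data", "train_data"):
        if hasattr(dataset, name):
            return getattr(dataset, name)
    raise AttributeError(
        "dataset has neither a 'data' nor a 'train_data' attribute")


def input_size_and_channels(shape):
    """Infer (input_size, input_channels) from an NHWC or NHW shape.

    Assumption: a 4-D shape is NHWC, a 3-D shape is NHW with one
    channel, and images are square (H == W).
    """
    if len(shape) == 4:
        channels = shape[3]
    elif len(shape) == 3:
        channels = 1
    else:
        raise ValueError("expected a 3-D or 4-D shape, got {}".format(shape))
    if shape[1] != shape[2]:
        raise ValueError("expected square images, got {}".format(shape))
    return shape[1], channels
```

    With these helpers, the second DARTS line could be written as `shape = dataset_array(trn_data).shape` on either version. For the CIFAR10 training set in newer releases, that shape is `(50000, 32, 32, 3)`, and `input_size_and_channels` gives `(32, 3)`.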

    The source of datasets.CIFAR10 (excerpt, up to the end of __init__) is as follows:

    from __future__ import print_function
    from PIL import Image
    import os
    import os.path
    import numpy as np
    import sys
    
    if sys.version_info[0] == 2:
        import cPickle as pickle
    else:
        import pickle
    
    from .vision import VisionDataset
    from .utils import check_integrity, download_and_extract_archive
    
    
    class CIFAR10(VisionDataset):
        """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
    
        Args:
            root (string): Root directory of dataset where directory
                ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
            train (bool, optional): If True, creates dataset from training set, otherwise
                creates from test set.
            transform (callable, optional): A function/transform that takes in a PIL image
                and returns a transformed version. E.g, ``transforms.RandomCrop``
            target_transform (callable, optional): A function/transform that takes in the
                target and transforms it.
            download (bool, optional): If true, downloads the dataset from the internet and
                puts it in root directory. If dataset is already downloaded, it is not
                downloaded again.
    
        """
        base_folder = 'cifar-10-batches-py'
        url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
        filename = "cifar-10-python.tar.gz"
        tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
        train_list = [
            ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
            ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
            ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
            ['data_batch_4', '634d18415352ddfa80567beed471001a'],
            ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
        ]
    
        test_list = [
            ['test_batch', '40351d587109b95175f43aff81a1287e'],
        ]
        meta = {
            'filename': 'batches.meta',
            'key': 'label_names',
            'md5': '5ff9c542aee3614f3951f8cda6e48888',
        }
    
        def __init__(self, root, train=True, transform=None, target_transform=None,
                     download=False):
    
            super(CIFAR10, self).__init__(root, transform=transform,
                                          target_transform=target_transform)
    
            self.train = train  # training set or test set
    
            if download:
                self.download()
    
            if not self._check_integrity():
                raise RuntimeError('Dataset not found or corrupted.' +
                                   ' You can use download=True to download it')
    
            if self.train:
                downloaded_list = self.train_list
            else:
                downloaded_list = self.test_list
    
            self.data = []
            self.targets = []
    
            # now load the pickled numpy arrays
            for file_name, checksum in downloaded_list:
                file_path = os.path.join(self.root, self.base_folder, file_name)
                with open(file_path, 'rb') as f:
                    if sys.version_info[0] == 2:
                        entry = pickle.load(f)
                    else:
                        entry = pickle.load(f, encoding='latin1')
                    self.data.append(entry['data'])
                    if 'labels' in entry:
                        self.targets.extend(entry['labels'])
                    else:
                        self.targets.extend(entry['fine_labels'])
    
            self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
            self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC
    
            self._load_meta()
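    The loop in `__init__` does three things. It reads each batch file with `pickle`, using `encoding='latin1'` on Python 3 so that pickles written by Python 2 can be decoded. It stacks the flat rows, and it converts them from channel-first to HWC. Each row holds 3072 values: 1024 red, then 1024 green, then 1024 blue, with each plane stored in row-major order. That layout is why the code reshapes to `(-1, 3, 32, 32)` before transposing. The sketch below reproduces these steps without torchvision. `parse_batch`, `load_batches` and `to_hwc` are hypothetical names; the NumPy calls are the same ones the source uses.

```python
import pickle

import numpy as np


def to_hwc(flat_rows):
    """Turn (N, 3072) rows into (N, 32, 32, 3) images.

    Each row stores the red plane, then green, then blue (1024 values each),
    so it is reshaped to N x C x H x W first and then transposed to N x H x W x C.
    """
    return flat_rows.reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))


def parse_batch(raw_bytes):
    """Decode one pickled CIFAR batch into (flat_rows, labels).

    encoding='latin1' lets Python 3 read pickles written by Python 2.
    """
    entry = pickle.loads(raw_bytes, encoding="latin1")
    # CIFAR-10 batches use 'labels'; CIFAR-100 batches use 'fine_labels'.
    labels = entry["labels"] if "labels" in entry else entry["fine_labels"]
    return entry["data"], list(labels)


def load_batches(paths):
    """Read several batch files and return (N, 32, 32, 3) images plus labels."""
    chunks, targets = [], []
    for path in paths:
        with open(path, "rb") as f:
            data, labels = parse_batch(f.read())
        chunks.append(data)
        targets.extend(labels)
    return to_hwc(np.vstack(chunks)), targets
```

    For the training set, `load_batches` would be called with the paths of `data_batch_1` through `data_batch_5` inside `cifar-10-batches-py`. That yields an array of shape `(50000, 32, 32, 3)`, which matches `trn_data.data.shape`. Unlike the torchvision class, this sketch performs no MD5 integrity check.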
  • Original article: https://www.cnblogs.com/yqpy/p/11831717.html