zoukankan      html  css  js  c++  java
  • PyTorch 训练前对数据加载、预处理

    参考:pytorch torchvision transform官方文档

    Pytorch学习--编程实战:猫和狗二分类

    深度学习框架PyTorch一书的学习-第五章-常用工具模块

    # coding:utf8
    import os
    from PIL import Image
    from torch.utils import data
    import numpy as np
    from torchvision import transforms as T
    
    
    class DogCat(data.Dataset):
    
        def __init__(self, root, transforms=None, train=True, test=False):
            ''''''
            '''
            主要目标: 获取所有图片的地址,并根据训练,验证,测试划分数据
            '''
            self.test = test
            imgs = [os.path.join(root, img) for img in os.listdir(root)]
    
            # test1: data/test1/8973.jpg
            # train: data/train/cat.10004.jpg
            if self.test:
                imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2].split('/')[-1]))
            else:
                imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2]))
    
            imgs_num = len(imgs)
    
            # shuffle imgs 打乱图片顺序
            np.random.seed(100)
            imgs = np.random.permutation(imgs)
    
            #训练集和验证集比例为7:3
            if self.test:
                self.imgs = imgs
            elif train:
                self.imgs = imgs[:int(0.7 * imgs_num)]  #训练集
            else:
                self.imgs = imgs[int(0.7 * imgs_num):]  #验证集
    
            if transforms is None:
                normalize = T.Normalize(mean=[0.485, 0.456, 0.406], #数据归一化处理
                                        std=[0.229, 0.224, 0.225])
    
                if self.test or not train:  #测试集+验证集
                    self.transforms = T.Compose([   #用来管理各个transform
                        T.Scale(224),#将输入的`PIL.Image`重新改变大小成给定的`size`,`size`是最小边的边长。
                        举个例子,如果原图的`height>width`,那么改变大小后的图片大小是
                        `(size*height/width, size)`。 T.CenterCrop(224), #以输入图像img的中心作为中心点进行指定size的crop操作 T.ToTensor(), #在做数据归一化之前必须要把PIL Image转成Tensor normalize #数据归一化处理 ]) else: #训练集 self.transforms = T.Compose([ T.Scale(256), T.RandomSizedCrop(224),#先将给定的PIL.Image随机切,然后再resize成给定的size大小。 T.RandomHorizontalFlip(),#随机水平翻转给定的PIL.Image,概率为0.5。即:
                一半的概率翻转,一半的概率不翻转。 T.ToTensor(), normalize ]) def __getitem__(self, index):
    '''''' ''' 一次返回一张图片的数据,并为训练集和验证集打标签 ''' img_path = self.imgs[index] if self.test: #测试集 label = int(self.imgs[index].split('.')[-2].split('/')[-1]) else: #验证集 训练集 label = 1 if 'dog' in img_path.split('/')[-1] else 0 data = Image.open(img_path) data = self.transforms(data) return data, label def __len__(self): return len(self.imgs)
  • 相关阅读:
    array_map()与array_shift()搭配使用 PK array_column()函数
    Educational Codeforces Round 8 D. Magic Numbers
    hdu 1171 Big Event in HDU
    hdu 2844 poj 1742 Coins
    hdu 3591 The trouble of Xiaoqian
    hdu 2079 选课时间
    hdu 2191 珍惜现在,感恩生活 多重背包入门题
    hdu 5429 Geometric Progression 高精度浮点数(java版本)
    【BZOJ】1002: [FJOI2007]轮状病毒 递推+高精度
    hdu::1002 A + B Problem II
  • 原文地址:https://www.cnblogs.com/cekong/p/11155836.html
Copyright © 2011-2022 走看看