zoukankan      html  css  js  c++  java
  • PyTorch 训练前对数据加载、预处理

    参考:pytorch torchvision transform官方文档

    Pytorch学习--编程实战:猫和狗二分类

    深度学习框架PyTorch一书的学习-第五章-常用工具模块

    # coding:utf8
    import os
    from PIL import Image
    from torch.utils import data
    import numpy as np
    from torchvision import transforms as T
    
    
    class DogCat(data.Dataset):
    
        def __init__(self, root, transforms=None, train=True, test=False):
            ''''''
            '''
            主要目标: 获取所有图片的地址,并根据训练,验证,测试划分数据
            '''
            self.test = test
            imgs = [os.path.join(root, img) for img in os.listdir(root)]
    
            # test1: data/test1/8973.jpg
            # train: data/train/cat.10004.jpg
            if self.test:
                imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2].split('/')[-1]))
            else:
                imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2]))
    
            imgs_num = len(imgs)
    
            # shuffle imgs 打乱图片顺序
            np.random.seed(100)
            imgs = np.random.permutation(imgs)
    
            #训练集和验证集比例为7:3
            if self.test:
                self.imgs = imgs
            elif train:
                self.imgs = imgs[:int(0.7 * imgs_num)]  #训练集
            else:
                self.imgs = imgs[int(0.7 * imgs_num):]  #验证集
    
            if transforms is None:
                normalize = T.Normalize(mean=[0.485, 0.456, 0.406], #数据归一化处理
                                        std=[0.229, 0.224, 0.225])
    
                if self.test or not train:  #测试集+验证集
                    self.transforms = T.Compose([   #用来管理各个transform
                        T.Scale(224),#将输入的`PIL.Image`重新改变大小成给定的`size`,`size`是最小边的边长。
                        举个例子,如果原图的`height>width`,那么改变大小后的图片大小是
                        `(size*height/width, size)`。 T.CenterCrop(224), #以输入图像img的中心作为中心点进行指定size的crop操作 T.ToTensor(), #在做数据归一化之前必须要把PIL Image转成Tensor normalize #数据归一化处理 ]) else: #训练集 self.transforms = T.Compose([ T.Scale(256), T.RandomSizedCrop(224),#先将给定的PIL.Image随机切,然后再resize成给定的size大小。 T.RandomHorizontalFlip(),#随机水平翻转给定的PIL.Image,概率为0.5。即:
                一半的概率翻转,一半的概率不翻转。 T.ToTensor(), normalize ]) def __getitem__(self, index):
    '''''' ''' 一次返回一张图片的数据,并为训练集和验证集打标签 ''' img_path = self.imgs[index] if self.test: #测试集 label = int(self.imgs[index].split('.')[-2].split('/')[-1]) else: #验证集 训练集 label = 1 if 'dog' in img_path.split('/')[-1] else 0 data = Image.open(img_path) data = self.transforms(data) return data, label def __len__(self): return len(self.imgs)
  • 相关阅读:
    【转】Tomcat中部署java web应用程序
    【转】如何安装mysql服务
    【转】java_web开发入门
    【转】SVN 查看历史信息
    【转】java编译错误 程序包javax.servlet不存在javax.servlet.*
    【转】MySQL5安装的图解(mysql-5.0.27-win32.zip)
    【转】JAVA变量path , classpth ,java_home设设置作用和作用
    intellij idea 10.5介绍
    Java中的IO与NIO
    javaWeb完成注册功能
  • 原文地址:https://www.cnblogs.com/cekong/p/11155836.html
Copyright © 2011-2022 走看看