zoukankan      html  css  js  c++  java
  • Pytorch框架学习---(2)输入数据操作

    本节讲述Data如何利用Pytorch提供的DataLoader进行读取,以及Transforms的图片处理方式。 【文中思维导图采用MindMaster软件】

    注意:笼统总结Transforms,目前仅具体介绍裁剪、翻转、标准化,后续随着代码需要,再逐步更新。

    一. 数据读取(DataLoader和Dataset)

    1.DataLoader

      我们采用Pytorch提供的DataLoader进行数据Batch封装,其中需要定义dataset类。

    自定义的dataset类需要复写def getitem(self, index):函数!!!

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=Batch_Size,
                              shuffle=True)
    
    for epoch in range(Max_Epoch):
        for i, (inputs, labels) in enumerate(train_loader):  # 每次调用一个batch,后台索引
    # 也可以采用next(iter(train_loader)), 读取一个批次
    

      在网络运行时,我们采用enumerate函数,进行迭代,这里会:

    • 进入DataLoader数据装载器;

    • 判断参数,是否采用多进程处理;

    • 调用Sampler函数,根据输入数据个数(由Dataset类中def len()函数得到),随机获取index索引值;

    • 进入我们定义的Dataset类,调用def getitem(),根据index获取数据,返回;

    • 调用collate_fn()函数整理数据,最终得到Batch。

    2.代码(如何将电脑中的数据送入网络?)

    注意:这里数据集已经分类好,文件夹已经各自建立,不包含划分数据的函数!!

    import torch
    from torch.utils.data import Dataset
    import os
    from PIL import Image
    import numpy as np
    import torchvision.transforms as transforms
    
    category = {"0": 0, "1": 1, "1_enhanced": 2, "1_enhanced_2": 3, "0_enhanced_1":4}  # 定义标签,"文件夹名":标签
    
    class my_dataset(Dataset):
        '''根据自己的数据,进行读取,Dataset类创建Pytorch数据集类型'''
        '''
        Args:
            data_dir: 数据地址(训练集、验证集、测试集)
            transform: torchvision.transforms(各种变换、以及Totensor)      
        Return:
            read_data  根据dataloader的索引获取数据
            len(self.data_info)  数据个数
        '''
    
        def __init__(self, data_dir, transform=None):
            self.transforms = transform
            self.data_info = self.get_dataset_info(data_dir)  # 获取所有数据路径和对应的标签,方便dataloader 用index批量处理
    
        def __getitem__(self, index):  # 当dataloader sampler得到index,根据该index索引dataset中数据
            path_data, label = self.data_info[index]
            read_data = Image.open(path_data).convert("RGB")  # PIL-->RGB(0-256)
    
            if self.transforms is not None:
                read_data = self.transforms(read_data)
    
            return read_data, label
    
        def __len__(self):
            return len(self.data_info)
    
        @staticmethod  # 定义该函数为静态类型,不用实例化类也可调用
        def get_dataset_info(data_dir):
            data_info = list()  # 最终包含所有图片、标签(每一行)
            for root, dirs, files in os.walk(data_dir):  # 获取当前文件夹的父目录、当前文件夹下所有文件名、所有内部文件
                for sub_dir in dirs:  # 遍历所有类别
                    each_cate = os.listdir(os.path.join(root, sub_dir))  # os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表。
    
                    for i in range(len(each_cate)):  # 遍历每一个类别下的图片数据,将标签一同嵌入
                        each_data_name = each_cate[i]
                        each_data_path = os.path.join(root, sub_dir, each_data_name)
                        each_label = category[sub_dir]
    
                        data_info.append((each_data_path, int(each_label)))
    
            return data_info
    

    二.数据预处理(torchvision.transforms)

    1.torchvision

    2.transforms.Compose([......])组合

      计算机将按照Compose中定义的transforms操作,依次进行数据处理。

    train_transforms = transforms.Compose([
        transforms.Resize((75, 75)),
        transforms.ToTensor(),  # (H x W x C) [0, 255] to a torch.FloatTensor (C x H x W) [0.0, 1.0]
        transforms.Normalize(mean=norm_mean,std=norm_std)  # 逐通道归一化,注意通道数
    ])
    

    3.各种transforms处理方式

      本节目前仅介绍:标准化Normalize、图像裁剪Crop、旋转翻转。

    (1)数据标准化

    transforms.Normalize(mean, std, inplace=False)  #逐通道对图像进行标准化,mean:(M1,...,Mn) and std: (S1,..,Sn) for n channels
    # input[channel] = (input[channel] - mean[channel]) / std[channel]
    

    (2)裁剪

    a)从中心进行裁剪

    transforms.CenterCrop(size=32)  # 由图像中心进行裁剪,size=32*32
    

    b)随机裁剪

    transforms.RandomCrop(size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant')
    # 先填充再随机裁剪
    # padding:设置填充大小,数值a --> 上下左右填充a个像素,(a,b)--> 左右a上下b, (a,b,c,d) --> 左a上b右c下d
    # padding_mode:填充模式:
          # constant:像素值由fill参数设定;
          # edge:由图像边缘像素决定;
          # reflect:镜像填充,最后一个像素不镜像;
          # symmetric:镜像填充,最后一个像素镜像。
    

    c)随机面积、随机长宽比裁剪图片

    transforms.RandomResizedCrop(size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR)
    # 先选择scale,再ratio,再判断size,是否需要interpolation进行resized
    # scale=(0.08, 1.0):随机裁剪面积比例,范围内随机选
    # ratio=(3. / 4., 4. / 3.):随机长宽比
    # interpolation:插值方法
    

    d)上下左右中心随机裁剪5张图片

    transforms.FiveCrop(size)  # 从上下左右中心各裁剪出五张图片
    transforms.TenCrop(size, vertical_flip=False)  # 先进行FiveCrop(),再对五张图片进行水平/垂直镜像,获得10张图片
    

    注意:这里返回的是tuple()类型,需要按行拼接起来,送入下游transforms处理。

    >>> transform = Compose([
             >>>    TenCrop(size), # this is a list of PIL Images
             >>>    Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor
             >>> ])
             >>> #In your test loop you can do the following:
             >>> input, target = batch # input is a 5d tensor, target is 2d
             >>> bs, ncrops, c, h, w = input.size()
             >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
             >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops
    

    有问题:当采用数据增强时,一方面采用TenCrop形式,另一方面采用其他数据变换,一同送入Dataloader时会产生错误,因为维度不一致,其他数据变换在dataset中为三维(channel,H,W),而TenCrop却是四维(ncrops,channel,H,W),于是当迭代获取Batch时会由于维度不匹配程序报错。
    解决方法:【等后续找到再来写,手动狗头微笑】

    (3)翻转、旋转

    transforms.RandomHorizontalFlip(p=0.5)  # 依概率进行水平(左右)翻转
    transforms.RandomVerticalFlip(p=0.5)  # 依概率进行垂直(上下)翻转
    transforms.RandomRotation(degrees, resample=False, expand=False, center=None)  # 随机旋转图片
          # degrees:旋转角度,若为a,则在(-a,a)之间二选一,若为(a, b),则二选一
          # expand:是否扩大图片(因为旋转过后可能会丢失图片某一块),仅针对中心点旋转
          # center:旋转点设置,默认中心点
    
    

    (4)对各种变换的组合--》选择操作(如RandomChoice)

    transforms.RandomChoice([transforms1, transforms2, ......])  # 随机挑选一个
    transforms.RandomApply([transforms1, transforms2, ......], p=0.5)  # 依概率执行整个一组(要么执行,要么不执行)
    transforms.RandomOrder([transforms1, transforms2, ......])  # 对一组操作进行打乱顺序,再去执行这一组
    

    4.自定义Transforms方法

    class YourTransforms(object):
        def __init__(self,Arg1,Arg2):
            '''传参数'''
        def __call__(self, x):
            '''定义该Transforms方法'''
            return x
    
  • 相关阅读:
    JAVA EE社团管理升级版-项目展示(微信小程序)
    JAVA EE社团管理升级版-微信WEB管理端说明文档
    python爬虫19 | 遇到需要的登录的网站怎么办?用这3招轻松搞定!
    python爬虫20 | 小帅b教你如何使用python识别图片验证码
    python爬虫16 | 你,快去试试用多进程的方式重新去爬取豆瓣上的电影
    python爬虫17 | 听说你又被封 ip 了,你要学会伪装好自己,这次说说伪装你的头部
    python爬虫18 | 就算你被封了也能继续爬,使用IP代理池伪装你的IP地址,让IP飘一会
    python爬虫15 | 害羞,用多线程秒爬那些万恶的妹纸们,纸巾呢?
    python爬虫13 | 秒爬,这多线程爬取速度也太猛了,这次就是要让你的爬虫效率杠杠的
    python爬虫14 | 就这么说吧,如果你不懂python多线程和线程池,那就去河边摸鱼!
  • 原文地址:https://www.cnblogs.com/zpc1001/p/13151254.html
Copyright © 2011-2022 走看看