zoukankan      html  css  js  c++  java
  • 【深度学习框架】使用PyTorch进行数据处理

      在深度学习中,数据的处理对于神经网络的训练来说十分重要,良好的数据(包括图像、文本、语音等)处理不仅可以加速模型的训练,同时也直接关系到模型的效果。本文以处理图像数据为例,记录一些使用PyTorch进行图像预处理和数据加载的方法


    一、数据的加载

      在PyTorch中,数据加载需要自定义数据集类,并用此类来实例化数据对象,实现自定义的数据集需要继承torch.utils.data包中的Dataset类
      在继承Dataset实现自己的类时,需要实现以下两个Python魔法方法:

    • __getitem__(index): 返回一个样本数据,当使用obj[index]时实际就是在调用obj.__getitem__(index)
    • __len__():返回样本的数量,当使用len(obj)时实际就是在调用obj.__len__()

      例如,以猫狗大战的二分类数据集为例,其加载过程如下:

    import os
    import torch as t
    from torch.utils import data
    from PIL import Image
    import numpy as np
    
    class dogCat(data.Dataset):
        def __init__(self,root): # root为数据存放目录
            imgs = os.listdir(root) #列出当前路径下所有的文件
            self.imgs = [os.path.join(root,img) for img in imgs] # 所有图片的路径
            #print(self.imgs)
    
    	"""返回一个样本数据"""
        def __getitem__(self, item): 
            img_path = self.imgs[item] # 第item张图片的路径
            #dog 1 cat 0
            label = 1 if 'dog' in img_path.split('\')[-1] else 0 # 获取标签信息
            #print(label)
            pil_img = Image.open(img_path) #读入图片
            print(type(pil_img))
            array = np.asarray(pil_img) # 转为numpy.array类型
            data = t.from_numpy(array) # 转为tensor类型
            return data,label #返回图片对应的tensor及其标签
    
    	"""样本的数量"""
        def __len__(self):
            return len(self.imgs)
    
    if __name__ == '__main__':
        dogcat = dogCat('D:pycodedogsVScatsdatacatvsdog\train') #数据集对象
        data,label = dogcat[0] # 返回第0张图片的信息
        print(data.size())
        print(label)
        print(len(dogcat))
    

    二、计算机视觉工具包:torchvision

      对于图像数据来说,以上的数据加载时不完善的,因为只是将图片读入,而没有进行相关的处理,如每张图片的大小和形状,样本的数值归一化等等。
      为了解决这一问题,PyTorch开发了一个视觉工具包torchvision,这个包独立于torch,需要通过pip install torchvision来单独安装。
      torchvision有三个部分组成:

    • models提供各种经典的网络结构和预训练好的模型,如AlexNet、VGG、ResNet、Inception等
    from torchvision import models
    from torch import nn
    resnet34 = models.resnet34(pretrained=True,num_classes=1000) # 加载预训练模型
    resnet34.fc=nn.Linear(512,10) # 修改全连接层为10分类
    
    • datasets提供了常用的数据集,如MNIST、CIFAR10/100、ImageNet、COCO等
    from torchvision import datasets
    dataset = datasets.MNIST('data/',download=True,train=False,transform=transform)
    

      除了常用数据集外,需要特别注意的是ImageFolder,ImageFolder假设所有的文件按文件夹存放,每个文件夹下面存储同一类的图片,文件夹的名字为这一类别的名字。这是我们经常用到的一种数据组织形式。

    # 使用方法:
    ImageFolder(root,transform=None,target_transform=None,loader=default_loader)
    # 参数:文件夹路径,对图像做什么样的转换,对标签做什么样的转换,如何加载图片
    
    from torchvision.datasets import ImageFolder
    dataset = ImageFolder('data\')
    print(dataset.class_to_idx) # class_to_idx ,label和id的对应关系,从0开始
    print(dataset.imgs) # 数据和标签对应
    
    • transforms: 提供常用的数据预处理操作,主要是对Tensor和PIL Image对象的处理操作

      对PIL Image的操作:Resize、CenterCrop、RandomCrop、RandomsizedCrop、Pad、ToTensor等。

      对Tensor的操作:Normalize、ToPILImage等。

      如果要进行多个操作,可以通过transforms.Compose([])将操作拼接起来。但是需要注意的是需要首先构建转换操作,然后再执行转换操作。

    import os
    from torch.utils import data
    from PIL import Image
    import numpy as np
    from torchvision import transforms as T
    
    transform = T.Compose([T.Resize(224),T.CenterCrop(224),T.ToTensor(),T.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])])  # 构建转换操作
    
    class dogCat(data.Dataset):
        def __init__(self,root,transforms):
            imgs = os.listdir(root)
            #print(imgs)
            self.imgs = [os.path.join(root,img) for img in imgs]
            #print(self.imgs)
            self.transforms = transforms
    
        def __getitem__(self, item):
            img_path = self.imgs[item]
            #dog 1 cat 0
            label = 1 if 'dog' in img_path.split('\')[-1] else 0
            #print(label)
            pil_img = Image.open(img_path)
            if self.transforms:
                pil_img = self.transforms(pil_img)  #执行准换操作
            return pil_img,label,item
    
        def __len__(self):
            return len(self.imgs)
    
    

    三、使用DataLoader进行数据再处理

      通过上述描述,我们通过自定义数据集类,使用视觉工具包进行图像的转换等操作,最终得到的是一个dataset的数据集对象,使用此对象可以一次返回一个样本。
      但是,我们应该清楚:训练神经网络时,一般采用的是小批量的梯度下降,因此我们是对一批数据进行处理,也就是一个batch,同时,数据还需要进行打乱(shuffle)和并行加速等。PyTorch提供了DataLoader来实现这些功能。
      DataLoader定义如下:

    DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,num_workers=0,collate_fn=default_collate,pin_memory=False,drop_last=False)
    

      参数含义如下:

    • dataset:加载的数据集
    • batch_zize: 批大小
    • shuffle: 是否将数据打乱
    • sampler:样本抽样,常用的有随机采样RandomSampler,shuffle=True时自动调用随机采样,默认是顺序采样,还有一个常用的是:WeightedRandomSampler,按照样本的权重进行采样。
    • num_workers: 使用的进程数,0代表不使用多进程。
    • collate_fn: 拼接方式。
    • pin_memory: 是否将数据保存在pin memory区。
    • drop_last: 是否将多出来的不足一个batch的丢弃。

      调用DataLoader得到的结果是一个可迭代的对象,可以和使用迭代器一样使用它。

    from torchvision import transforms as T
    from torch.utils.data import DataLoader
    
    transform = T.Compose([T.Resize(224),T.CenterCrop(224),T.ToTensor(),T.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])])
    
    if __name__ == '__main__':
        dogcat = dogCat('D:pycodedogsVScatsdatacatvsdog\train', transform)
        data, label, index = dogcat[0]
        
        dataloader = DataLoader(dogcat,batch_size=3,shuffle=False,num_workers=0,drop_last=False)
        for batchDatas,batchLabels in dataloader: 
            train()
    

    总结

      本文记录了使用PyTorch进行数据预处理的相关操作流程,重点是掌握Dataset和DataLoader两个类的使用,另外,视觉工具包torchvision的三个模块灵活运用,会对数据处理过程有很好的帮助。

    博学 审问 慎思 明辨 笃行
  • 相关阅读:
    访问修饰符、封装、继承
    面向对象与类
    内置对象
    三级联动 控件及JS简单使用
    asp。net简单的登录(不完整)
    asp。net:html的表单元素:
    ASP.Net简介及IIS服务器及Repeater
    用户控件
    登陆,激活,权限
    timer控件,简单通讯
  • 原文地址:https://www.cnblogs.com/gzshan/p/10628289.html
Copyright © 2011-2022 走看看