zoukankan      html  css  js  c++  java
  • PyTorch数据处理,datasets、DataLoader及其工具的使用

    torchvision是PyTorch的一个视觉工具包,提供了很多图像处理的工具。

    datasets使用ImageFolder工具(默认PIL Image图像),获取定制化的图片并自动生成类别标签。如裁剪、旋转、标准化、归一化等(使用transforms工具)。

    DataLoader可以把datasets数据集打乱,分成batch,并行加速等。

    一、datasets获取原图或格式化的图,自动命名标签

    1.1 获取原图片

    使用torchvision.datasets中的ImageFolder工具,功能:

    1、文件夹名就是类别名

    2、从上到下自动为文件夹自动创建标签,0、1、2、...。class_to_idx、imgs属性可以查看。

    3、返回每一幅图的data、label

     

    from torchvision.datasets import ImageFolder
    
    dataset=ImageFolder("E:/data/dogcat_2/train/") #获取路径,返回的是所有图的data、label
    print(dataset.class_to_idx) #查看类别名,及对应的标签。
    print(dataset.imgs)  #查看路径里所有的图片,及对应的标签

    print(dataset[0][1]) #第1张图的label
    dataset[0][0] #第1张图的data

    1.2 获取定制化的图片,启用ImageFolder的transform参数

    使用torchvision的transforms工具,常用功能:

    Resize——调整大小
    CenterCrop、RandomCrop、RandomSizedCrop——裁剪
    Pad——填充
    ToTensor——PIL Image转Tensor,自动[0,255]归一化到[0,1]
    Normalize——标准化,即减均值,除以标准差
    ToPILImage——Tensor转PIL Image

    这些操作可以放到一起——Compose

    from torchvision import transforms as T
    
    #设置格式化条件
    transform=T.Compose([T.Resize((200,200)), #缩放为200*200方形
                         T.RandomHorizontalFlip(), #水平翻转
                         T.ToTensor(), #PIL Image转Tensor,[0,255]自动归一化为[0,1]
                         T.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5]) #标准化,减均值除标准差
                        ])
    #启用ImageFolder的transform参数,获取格式化图像
    dataset=ImageFolder("E:/data/dogcat_2/train/",transform=transform)
    
    dataset[0][0].size() #查看图像大小,3*224*224

    #展示图像,乘标准差加均值,再转回PIL Image(上述过程的逆过程)
    show=T.ToPILImage()
    show(dataset[0][0]*0.5+0.5)

    二、DataLoader处理datasets

    from torch.utils.data import DataLoader
    dataloader=DataLoader( dataset,batch_size=4,shuffle=True,num_workers=2 ) #4幅图为1个batch,打乱,2个进程加速
    #### 显示第1个batch的4幅图(随机)
    from torchvision.transforms import ToPILImage
    from torchvision.utils import make_grid
    dataiter = iter(dataloader) #DataLoader是可迭代的
    (images, labels) = dataiter.next() #第一个batch
    print(labels) #打印标签
    show=ToPILImage() 
    show(make_grid(images*0.5+0.5)).resize((4*100,100))  #以100*100展示第一个batch

  • 相关阅读:
    MQ
    redis
    MongoDB
    进程相关命令
    catalina.sh
    tomcat-jvm
    中间件简介
    websphere
    mysql
    shell变量与字符串操作
  • 原文地址:https://www.cnblogs.com/xixixing/p/12759849.html
Copyright © 2011-2022 走看看