zoukankan      html  css  js  c++  java
  • torchvision.transforms模块介绍

    torchvision.transforms模块

    官网地址:https://pytorch.org/docs/stable/torchvision/transforms.html#

    torchvision是独立于PyTorch的关于图像操作的一个工具库,目前包括六个模块:

    • torchvision.datasets:几个常用视觉数据集,可以下载和加载,以及如何编写自己的Dataset。
    • torchvision.models:经典模型,例如AlexNet、VGG、ResNet等,以及训练好的参数。
    • torchvision.transforms:常用的图像操作,例随机切割、旋转、数据类型转换、tensor与numpy 和PIL Image的互换等。
    • torchvision.ops:提供CV中常用的一些操作,比如NMS、ROI_Align、ROI_Pool等。
    • torchvision.io:提供输入输出的一些操作,目前针对的是视频的写入写出。
    • torchvision.utils:其他工具,比如产生一个图像网格等。

    这里主要介绍torchvision.transforms模块。

    torchvision.transforms模块按照功能,可分为5个部分,所有转换均可用torchvision.transforms.Compose() 来组合。

    • Transforms on PIL Image:在PIL Image上进行的转换,比如随机翻转、剪切等。
    • Transforms on torch.Tensor:在tensor上进行的转换,最常用的是归一化操作transforms.Normalize(mean, std, inplace=False)。
    • Conversion Transforms:PIL.Image/numpy.ndarray与Tensor的相互转换。
    • Generic Transforms:提供自定义转换接口。
    • Functional Transforms:不同于前面的转换,这里可以提供更细粒度的控制,需要自己提供随机生成器或指定参数。

    下面重点介绍PIL.Image/numpy.ndarray与Tensor的相互转换,归一化,对PIL.Image进行裁剪、缩放等操作。

    1 PIL.Image/numpy.ndarray与Tensor的相互转换

    PIL.Image/numpy.ndarray转化为Tensor,常常用在训练模型阶段的数据读取,而Tensor转化为PIL.Image/numpy.ndarray则用在验证模型阶段的数据输出。

    from torchvision import transforms
    
    transform1 = transforms.Compose([
        transforms.ToTensor() #PIL Image/ndarray (H,W,C) [0,255] to tensor (C,H,W) [0.0,1.0]
        ]) 
    
    
    ##numpy.ndarray与Tensor的相互转换
    import cv2
    import numpy as np
    
    img_path = 'Lenna.png'
    img1 = cv2.imread(img_path) #img1格式为ndarray  (512,512,3)  uint8  BGR
    img_1 = transform1(img1) #tensor  (3,512,512)  float32  范围是[0.0,1.0]
    #将转换后的tensor还原成ndarray
    img_11 = (img_1.numpy() * 255).astype('uint8')
    img_11 = np.transpose(img_11, (1,2,0))
    #判断两者是否相等 
    print((img1==img_11).all()) #True
    #显示
    cv2.imshow('img_11', img_11)
    cv2.waitKey()
    
    
    ##PIL.Image与Tensor的相互转换
    from PIL import Image
    
    img2 = Image.open(img_path) #为PIL图像对象,即PIL.PngImagePlugin.PngImageFile,默认RGB
    img_2 = transform1(img2) #tensor  (3,512,512)  float32  范围是[0.0,1.0] 
    #将转换后的tensor还原成PIL Image
    img_22 = transforms.ToPILImage()(img_2)  #PIL.Image.Image
    img_22.show()
    

    2 归一化 transforms.Normalize

    transforms.Normalize使用该公式进行归一化:channel = (channel-mean) / std.

    上面的示例中,将transform1改成下面的transform2,即可将tensor数据的范围由[0.0,1.0]归一化到[-1.0, 1.0]

    transform2 = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))
        ])
    

    3 PIL.Image的缩放裁剪等操作

    transforms还提供了裁剪缩放等操作,以便进行数据增强。下面就看一个随机裁剪的例子,这个例子中,仍然使用 Compose 将 transforms 组合在一起。注意,这里对图像的操作主要是针对PIL.Image对象,所以需要先转换成PIL.Image格式。

    transform3 = transforms.Compose([
        transforms.ToTensor(), 
        transforms.ToPILImage(),
        transforms.RandomCrop((300,300)),
        ])
    
    img = Image.open(img_path)
    img3 = transform3(img)
    img3.show()
    

    Reference:

  • 相关阅读:
    面向对象下
    面向对象上
    将博客搬至CSDN
    矩阵的常用术语和操作
    python2.7 Unable to find vcvarsall.bat
    intellij创建maven web项目
    intellij 编译 springmvc+hibernate+spring+maven 找不到hbm.xml映射文件
    spring Thymeleaf 中文乱码
    visualstudiocode 调试electron
    android反编译工具总结
  • 原文地址:https://www.cnblogs.com/inchbyinch/p/12091791.html
Copyright © 2011-2022 走看看