zoukankan      html  css  js  c++  java
  • torchvision 之 transforms 模块详解

    torchvision 是独立于 PyTorch 的关于图像操作的一个工具库,目前包括六个模块:

       1)torchvision.datasets:几个常用视觉数据集,可以下载和加载,以及如何编写自己的 Dataset。

       2)torchvision.models:经典模型,例如 AlexNet、VGG、ResNet 等,以及训练好的参数。

       3)torchvision.transforms:常用的图像操作,例随机切割、旋转、数据类型转换、tensor 与 numpy 和 PIL Image 的互换等。

       4)torchvision.ops:提供 CV 中常用的一些操作,比如 NMS、ROI_Align、ROI_Pool 等。

       5)torchvision.io:提供输入输出的一些操作,目前针对的是视频的写入写出。

       6)torchvision.utils:其他工具,比如产生一个图像网格等。

    这里主要介绍下 torchvision.transforms 模块。torchvision.transforms 是 pytorch 中的图像预处理包。一般用 Compose 把多个步骤整合到一起。

    """
    transforms: list of Transform objects, 是一个列表
    """
    class torchvision.transforms.Compose(transforms)
    

     事实上,Compose()类会对 transforms 列表里面的 transform 操作进行遍历。实现的代码很简单,截取部分源码如下:

    def __call__(self, img):
        for t in self.transforms:   
            img = t(img)
        return img
    

    transforms 中的常见图像操作:

    1. transforms.ToTensor 

       Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]。

       这个变换改变了图像的参数顺序,最终得到的图像形状为 $(C,H,W)$,并转换为 Tensor 类型,归一化至 [0,1] 是直接除以 255,每个像素变成一个 32 位

       的浮点类型。

    2. transforms.Normalize

    """
    mean (sequence) – Sequence of means for each channel.
    std (sequence) – Sequence of standard deviations for each channel.
    """
    torchvision.transforms.Normalize(mean, std)

       当数据量很大的时候,每个通道的数据都可以看成正态分布(大数定律),求出每个通道数据对应的均值和标准差,然后利用这两个值将每个通道数据的分布

       转换为标准正态分布。

  • 相关阅读:
    地图篇-02.地理编码
    地图篇-01.获取用户位置
    新手教程之使用Xib自定义UITableViewCell
    封装
    NSDate简单介绍
    OC知识点归纳
    Xcode的控制台调试命令
    [开发笔记]UIApplication介绍
    技术分享-开发利器block底层实现
    技术分享-开发利器block
  • 原文地址:https://www.cnblogs.com/yanghh/p/14089302.html
Copyright © 2011-2022 走看看