zoukankan      html  css  js  c++  java
  • PyTorch学习记录004-torchvision

    0.模块

      torchvision有4个功能模块:model、datasets、transforms和utils。利用datasets可以下载一些经典数据集,本次笔记主要记录如何使用datasets的ImageFolder处理自定义数据集,以及如何使用transforms对源数据进行预处理、增强等。
    

    1.transforms

      transforms提供了对PIL Image对象和Tensor对象的常用操作。
    

    1)对PIL Image的常见操作如下。

      Scale/Resize:调整尺寸,长宽比保持不变。
      CenterCrop、RandomCrop、RandomSizedCrop:裁剪图片,CenterCrop和RandomCrop在crop时是固定size,RandomResizedCrop则是random size的crop。
      Pad:填充。
      ToTensor:把一个取值范围是[0,255]的PIL.Image转换成Tensor。形状为(H,W,C)的Numpy.ndarray转换成形状为[C,H,W],取值范围是[0,1.0]的torch.FloatTensor。
      RandomHorizontalFlip:图像随机水平翻转,翻转概率为0.5。
      RandomVerticalFlip:图像随机垂直翻转。
      ColorJitter:修改亮度、对比度和饱和度。
    

    2)对Tensor的常见操作如下。

      Normalize:标准化,即,减均值,除以标准差。
      ToPILImage:将Tensor转为PIL Image。
      如果要对数据集进行多个操作,可通过Compose将这些操作像管道一样拼接起来,类似于nn.Sequential。以下为示例代码:
      这个东西会被送入你自定义的Dataset中!
    
    transforms.Compose([
    	# 将给定的PIL.Image进行中心切割,得到给定的size
    	# size可以是tuple,(target_height, target_width)
    	# size也可以是一个Integer, 切出来一个正方形
    	transform.CenterCrop(10)
    	# 切割中心点的位置随机选取
    	transforms.RandomCrop(20, padding=0)
    	# 将一个取值范围是[0,255]的PIL.Image或者shape为(H,W,C)的numpy.ndarray
    	# 转换为形状为(C,H,W),取值范围是[0,1]的torch.FloatTensor
    	transforms.ToTensor()
    	# 规范化到[-1, -1]
    	transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5,0.5,0.5))
    ])
    
    

    2.datasets.ImageFolder

      当文件依据标签处于不同文件下时,如:
    

      我们可以利用torchvision.datasets.ImageFolder来直接构造出dataset
    
    loader = datasets.ImageFolder(path)
    loader = data.DataLoader(dataset)
    
      ImageFolder会将目录中的文件夹名自动转化成序列,当DataLoader载入时,标签自动就是整数序列了。
      下面我们利用ImageFolder读取不同目录下的图片数据,然后使用transforms进行图像预处理,预处理有多个,我们用compose把这些操作拼接在一起。然后使用DataLoader加载。对处理后的数据用torchvision.utils中的save_image保存为一个png格式文件,然后用Image.open打开该png文件,详细代码如下:
    
    from torchvision import transforms, utils
    from torchvision import datasets
    from torch.utils.data import DataLoader
    
    import matplotlib.pyplot as plt
    
    
    my_trans = transforms.Compose([
        transforms.RandomResizedCrop(224), #将给定图像随机裁剪为不同的大小和宽高比,然后缩放所裁剪得到的图像为制定的大小
        transforms.RandomHorizontalFlip(), #图像水平翻转
        transforms.ToTensor()
    ])
    
    train_data = datasets.ImageFolder(r'./data/torchvision_data', transform = my_trans)
    train_loader = DataLoader(train_data, batch_size=8, shuffle=True)
    for i_batch, img in enumerate(train_loader):
        if i_batch == 0:
            print(img[1])
            fig = plt.figure()
            grid = utils.make_grid(img[0])
            plt.imshow(grid.numpy().transpose((1, 2, 0)))
            plt.show()
            utils.save_image(grid,'test02.png')
        break
    
    
      这里我建立一个torchvision_data文件夹,把不同类型的图片放在不同的子文件夹下。
    

    运行结果为:

      [参考](https://blog.csdn.net/qq_39610915/category_10487496.html)
  • 相关阅读:
    Eclipse安装TestNG插件
    总结Selenium WebDriver中一些鼠标和键盘事件的使用
    【资料收集】AutomationGuru
    centos7.4 yum安装包出现网络不可达跟Recv failure: Connection reset by peer" 这个问题
    ubuntu配置ntp
    OpenStack-ansible ubuntu16.04安装&& centos7 安装 && openSUSE 安装OpenStack-ansible
    HSRP&&STP&&ACL
    vlan通讯&&动态路由
    cisco交换机基本配置
    cisco教程 怎么改console密码 主机名 各种模式的切换等
  • 原文地址:https://www.cnblogs.com/wzdszh/p/14109553.html
Copyright © 2011-2022 走看看