zoukankan      html  css  js  c++  java
  • 0702-计算机视觉工具包torchvision

    0702-计算机视觉工具包torchvision

    pytorch完整教程目录:https://www.cnblogs.com/nickchen121/p/14662511.html

    一、torchvision 概述

    计算机视觉是深度学习中最重要的一类应用,为了方便研究者使用,torch 专门开发了一个视觉工具包 torchvision,这个包独立于 torch,需要使用 pip install torchvision 进行安装。

    之前的我们已经使用过它的部分功能,在这里我们在做一个系统的介绍,它主要包含以下三个功能:

    • models:提供深度学习中各种经典网络的网络结构以及训练好的模型,包括 Alex-Net、VGG 系列、ResNet 系列、Inception 系列等
    • datasets:提供常用的数据集加载,设计上都是集成 torch.utils.data.Dataset,主要包括 MNIST、CIFAR10/100、ImageNet、COCO 等
    • transforms:提供常用的数据预处理操作,主要包括对 Tensor 以及 PIL Image 对象的操作

    二、通过 torchvision 加载模型

    from torchvision import models
    from torch import nn
    
    # 加载预训练好的模型,如果不存在会下载
    # 预训练好的模型保存在 ~/.torch/modes/ 下面
    resnet34 = models.resnet34(pretrained=True, num_classes=1000)
    
    # 修改最后的全连接层为 10 分类问题(默认是 ImageNet 上的 1000 分类)
    resnet34.fc = nn.Linear(512, 10)
    

    三、通过 torchvision 加载并处理数据集

    from torchvision import datasets
    from torchvision import transforms as T
    # 指定数据集路径为 data,如果数据集不存在则进行下载
    # 通过 train=False 获取测试集
    
    normalize = T.Normalize(mean=[0.4, 0.4, 0.4], std=[0.2, 0.2, 0.2])
    transform = T.Compose([
        T.RandomResizedCrop(224),
        T.RandomHorizontalFlip(),
        T.ToTensor(),  # 把图片转成 Tensor,归一化至 [0,1]
        T.Lambda(lambda x: x.repeat(3, 1, 1)),  # 把图片转为 3 通道的
        normalize,
    ])
    
    dataset = datasets.MNIST('data/',
                             download=True,
                             train=False,
                             transform=transform)
    

    Transforms 中涵盖了大部分对 Tensor 和 PIL Image 的常用处理,这个转换通常分为两步:

    1. 第一步:构建转换操作,例如 transf = transforms.Normalize(mean=x, std=y)
    2. 第二步:执行转换操作,例如 otuput = transf(inp)
    import torch as t
    
    # 构建随机噪声,图片如下图所示
    to_pil = T.ToPILImage()
    to_pil(t.rand(3, 64, 64))
    

    四、通过 torchvision 拼接并保存图片

    torchvision 还提供了两个常用的函数:

    1. make_grid,它能把多张图片拼接在一个网格中
    2. save_img,它能把 Tensor 保存成图片
    len(dataset)
    
    10000
    
    from torch.utils.data import DataLoader
    
    dataloader = DataLoader(dataset, shuffle=True, batch_size=16)
    from torchvision.utils import make_grid, save_image
    dataiter = iter(dataloader)
    dataiter
    img = make_grid(next(dataiter)[0], 4)  # 拼接成 4*4 网格图片,并且会转成 3 通道,如下图所示
    to_img = T.ToPILImage()
    to_img(img)
    
    save_image(img, 'a.png')
    from PIL import Image
    Image.open('a.png')
    
  • 相关阅读:
    解决EXC_BAD_ACCESS错误的一种方法--NSZombieEnabled
    关于deselectRowAtIndexPath
    CGRectInset、CGRectOffset、等对比整理
    代码设置UITableViewCell 各个组件间距
    UITableViewCell计算行高
    设置UITableView中UIImage的大小
    UILbale自动换行
    根据字体多少使UILabel自动调节尺寸
    ios通过url下载显示图片
    Python【requests】第三方模块
  • 原文地址:https://www.cnblogs.com/nickchen121/p/14711902.html
Copyright © 2011-2022 走看看