写到前面
这是torchvision.utils
模块里面的两个方法,因为比较常用,所以pytorch
直接封装好了。
制作网格
网络图像一般用于训练数据或测试数据的可视化。
torchvision.utils.make_grid(tensor, nrow, padding) → torch.Tensor
- 描述
将多张tensor
格式的图像以网格的方式封装到一起。
- 参数
tensor
(tensor or list):四维 (B x C x H x W)
mini-batch的tensor
数据或者是包含同一尺寸的图片列表。
nrow
(int):网格每行图片的个数,默认是8
;千万不要理解为图片的行数。
padding
(int):四周填充的宽度,默认是2
,你可以理解为网格中图片之间的间距。默认填充值是0
,也就是黑色。
注:这是三个比较常用的参数,其它参数请参考官方文档。
- 示例
# 以mnist数据集为例,train_loader的batch_size设置为9
images, labels = next(iter(train_loader))
print(images.size()) # torch.Size([9, 1, 28, 28])
images = torchvision.utils.make_grid(images, 3, 0)
print(images.size()) # torch.Size([3, 84, 84])
- 绘图
保存本地
tensor
数据类型保存时不用再转为PIL.Image
或numpy.ndarray
,pytorch
直接给我们写好了一个方法。
torchvision.utils.save_image(tensor, fp) → None
- 描述
直接将tensor数据保存为图像。
- 参数
tensor
(Tensor or list):待保存的tensor
数据。如果给以一个四维的mini-batch
的tensor
,将调用网格方法,然后再保存到本地。
fp
(string or file object)):图像的保存路径。
注:这是两个比较常用的参数,其它参数请参考官方文档。
- 示例
images, labels = next(iter(train_loader))
print(images.size()) # torch.Size([9, 1, 28, 28])
images = torchvision.utils.make_grid(images, 3, 0)
print(images.size()) # torch.Size([3, 84, 84])
torchvision.utils.save_image(images, 'test.jpg')
完整代码
#%% 导入模块
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image
#%% 下载数据集
train_file = datasets.MNIST(
root='./dataset/',
train=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]),
download=True
)
#%% 制作数据加载器
train_loader = DataLoader(
dataset=train_file,
batch_size=9,
shuffle=True
)
#%% 训练数据可视化
images, labels = next(iter(train_loader))
print(images.size()) # torch.Size([9, 1, 28, 28])
images = make_grid(images, 3, 0)
print(images.size()) # torch.Size([3, 84, 84])
save_image(images, 'test.jpg')