torchvision.transforms
先看看一个transform是干啥的:
transform通常用于处理图像,它将图像进行一定的变换(transform),具体来说有:
class torchvision.transforms.ToTensor
把一个取值范围是[0,255]
的PIL.Image
或者shape
为(H,W,C)
的numpy.ndarray
,转换成形状为[C,H,W]
,取值范围是[0,1.0]
的torch.FloadTensor,例如
data = np.random.randint(0, 255, size=300)
img = data.reshape(10,10,3)
print(img.shape)
img_tensor = transforms.ToTensor()(img) # 转换成tensor
print(img_tensor)
class torchvision.transforms.Normalize(mean, std)
给定均值:(R,G,B)
方差:(R,G,B)
,将会把Tensor
正则化。即:Normalized_image=(image-mean)/std
class torchvision.transforms.Resize(size, interpolation=2)
将输入的`PIL.Image`重新改变大小成给定的`size`,`size`是最小边的边长。举个例子,如果原图的`height>width`,那么改变大小后的图片大小是`(size*height/width, size)`。interpolation是插值方式,默认为PIL.Image.BILINEAR。
然后可以将这些变换集成为一个:
class torchvision.transforms.Compose(transforms)
将多个transform
组合起来使用。