torch.unsqueeze()和torch.squeeze()
1.torch.unsqueeze()
原型:
torch.unsqueeze(input, dim, out=None)
作用:扩展维度,返回一个新的张量,对输入徳既定位置插入维度1。
参数:
tensor (Tensor)
– 输入张量dim (int)
– 插入维度的索引out (Tensor, optional)
– 结果张量
2.torch.squeeze()
原型:
torch.squeeze(input, dim=None, out=None)
作用:降低维度,将输入张量形状中的1 去除并返回。 如果输入是形如(A×1×B×1×C×1×D),那么输出形状就为: (A×B×C×D)
当给定dim时,那么挤压操作只在给定维度上。例如,输入形状为: (A×1×B), squeeze(input, 0)
将会保持张量不变,只有用 squeeze(input, 1)
,形状会变成 (A×B)。
参数:
input (Tensor)
– 输入张量dim (int, optional)
– 如果给定,则input只会在给定维度挤压out (Tensor, optional)
– 输出张量