转自:https://blog.csdn.net/xiexu911/article/details/80820028
1.torch.squeeze()
只会去掉维度为1的那个维度。它只会去掉维度为1的维度,像下面的没有为1的维度,就不会改变:
>>> aaa=np.ones((3,2)) >>> aaa array([[1., 1.], [1., 1.], [1., 1.]]) >>> aaa.squeeze() array([[1., 1.], [1., 1.], [1., 1.]])
https://docs.scipy.org/doc/numpy/reference/generated/numpy.squeeze.html
2.torch.unsqueeze()
>>> a=torch.tensor([1,2]) >>> a.size() torch.Size([2]) >>> b=a.unsqueeze(1) >>> b tensor([[1], [2]]) >>> b.size() torch.Size([2, 1]) >>> c=a.unsqueeze(0) >>> c tensor([[1, 2]]) >>> c.size() torch.Size([1, 2])
就是在第i个维度上多加一个维度,对于b来说,是第二个维度,对于c来说,是第一个维度。
原来是[2]长的list那种形式,在unsqueeze(0)的时候就变成了[1,2],在unsqueeze(1)的时候就变成了[2,1]