Tensor的组合与分块
组合操作是指将不同的Tensor叠加起来, 主要有torch.cat()和torch.stack()两个函数。 cat即concatenate的意思, 是指沿着已有的数据的某一维度进行拼接, 操作后数据的总维数不变, 在进行拼接时, 除了拼接的维度之外, 其他维度必须相同。 而torch.stack()函数指新增维度, 并按照指定的维度进行叠加,
1 import torch 2 3 # 创建两个2×2的Tensor 4 a = torch.Tensor([[1,2],[3,4]]) 5 print(a,a.shape) 6 7 b = torch.Tensor([[5,6],[7,8]]) 8 print(b,b.shape) 9 10 # 以第一维进行拼接, 则变成4×2的矩阵 11 c = torch.cat([a,b],0) 12 print(c,c.shape) 13 14 # 以第二维进行拼接, 则变成2*4的矩阵 15 d = torch.cat([a,b],1) 16 print(d,d.size())
结果输出:
1 tensor([[1., 2.], 2 [3., 4.]]) torch.Size([2, 2]) 3 tensor([[5., 6.], 4 [7., 8.]]) torch.Size([2, 2]) 5 tensor([[1., 2.], 6 [3., 4.], 7 [5., 6.], 8 [7., 8.]]) torch.Size([4, 2]) 9 tensor([[1., 2., 5., 6.], 10 [3., 4., 7., 8.]]) torch.Size([2, 4])
1 import torch 2 3 # 创建两个2×2的Tensor 4 a = torch.Tensor([[1,2],[3,4]]) 5 print(a,a.shape) 6 7 >> tensor([[1., 2.], 8 [3., 4.]]) torch.Size([2, 2]) 9 10 b = torch.Tensor([[5,6],[7,8]]) 11 print(b,b.shape) 12 13 >> tensor([[5., 6.], 14 [7., 8.]]) torch.Size([2, 2]) 15 16 # 以第0维进行stack, 叠加的基本单位为序列本身, 即a与b, 因此输出[a, b], 输出维度为2×2×2 17 d=torch.stack([a,b],0) 18 print(d, d.size()) 19 >> tensor([[[1., 2.], 20 [3., 4.]], 21 22 [[5., 6.], 23 [7., 8.]]]) torch.Size([2, 2, 2]) 24 25 # 以第1维进行stack, 叠加的基本单位为每一行, 输出维度为2×2×2 26 e=torch.stack([a,b],1) 27 print(e, e.shape) 28 29 >> tensor([[[1., 2.], 30 [5., 6.]], 31 32 [[3., 4.], 33 [7., 8.]]]) torch.Size([2, 2, 2]) 34 35 # 以第2维进行stack, 叠加的基本单位为每一行的每一个元素, 输出维度为2×2×2 36 f=torch.stack([a,b],2) 37 print(f, f.shape) 38 39 >> tensor([[[1., 5.], 40 [2., 6.]], 41 42 [[3., 7.], 43 [4., 8.]]]) torch.Size([2, 2, 2])
分块则是与组合相反的操作, 指将Tensor分割成不同的子Tensor,主要有torch.chunk()与torch.split()两个函数, 前者需要指定分块的数量,而后者则需要指定每一块的大小, 以整型或者list来表示。 具体示例如下 :
1 import torch 2 3 a = torch.Tensor([[1,2,3], [4,5,6]]) 4 print(a, a.size()) 5 >> tensor([[1., 2., 3.], 6 [4., 5., 6.]]) torch.Size([2, 3]) 7 8 # 使用chunk, 沿着第0维进行分块, 一共分两块, 因此分割成两个1×3的Tensor 9 b = torch.chunk(a, 2, 0) 10 print(b) 11 >> (tensor([[1., 2., 3.]]), tensor([[4., 5., 6.]])) 12 13 # 沿着第1维进行分块, 因此分割成两个Tensor, 当不能整除时, 最后一个的维数会小于前面的 14 # 因此第一个Tensor为2×2, 第二个为2×1 15 c = torch.chunk(a, 2, 1) 16 print(c) 17 >> (tensor([[1., 2.], 18 [4., 5.]]), tensor([[3.], 19 [6.]])) 20 21 # 使用split, 沿着第0维分块, 每一块维度为2, 由于第一维维度总共为2, 因此相当于没有分割 22 d = torch.split(a, 2, 0) 23 print(d) 24 >> (tensor([[1., 2., 3.], 25 [4., 5., 6.]]),) 26 27 # 沿着第1维分块, 每一块维度为2, 因此第一个Tensor为2×2, 第二个为2×1 28 e = torch.split(a, 2, 1) 29 print(e) 30 >> (tensor([[1., 2.], 31 [4., 5.]]), tensor([[3.], 32 [6.]])) 33 34 # split也可以根据输入的list进行自动分块, list中的元素代表了每一个块占的维度 35 f = torch.split(a, [1,2], 1) 36 print(f) 37 >> (tensor([[1.], 38 [4.]]), tensor([[2., 3.], 39 [5., 6.]]))