zoukankan      html  css  js  c++  java
  • Tensor的组合与分块-02

    Tensor的组合与分块


      组合操作是指将不同的Tensor叠加起来, 主要有torch.cat()torch.stack()两个函数。 catconcatenate的意思, 是指沿着已有的数据的某一维度进行拼接, 操作后数据的总维数不变, 在进行拼接时, 除了拼接的维度之外, 其他维度必须相同。 而torch.stack()函数指新增维度, 并按照指定的维度进行叠加,

     1 import torch 
     2 
     3 # 创建两个2×2的Tensor
     4 a = torch.Tensor([[1,2],[3,4]])
     5 print(a,a.shape)
     6 
     7 b = torch.Tensor([[5,6],[7,8]])
     8 print(b,b.shape)
     9 
    10 # 以第一维进行拼接, 则变成4×2的矩阵
    11 c = torch.cat([a,b],0)
    12 print(c,c.shape)
    13 
    14 # 以第二维进行拼接, 则变成2*4的矩阵
    15 d = torch.cat([a,b],1)
    16 print(d,d.size())
    View Code

    结果输出:

     1 tensor([[1., 2.],
     2         [3., 4.]]) torch.Size([2, 2])
     3 tensor([[5., 6.],
     4         [7., 8.]]) torch.Size([2, 2])
     5 tensor([[1., 2.],
     6         [3., 4.],
     7         [5., 6.],
     8         [7., 8.]]) torch.Size([4, 2])
     9 tensor([[1., 2., 5., 6.],
    10         [3., 4., 7., 8.]]) torch.Size([2, 4])
    View Code
     1 import torch 
     2 
     3 # 创建两个2×2的Tensor
     4 a = torch.Tensor([[1,2],[3,4]])
     5 print(a,a.shape)
     6 
     7 >>   tensor([[1., 2.],
     8             [3., 4.]]) torch.Size([2, 2])
     9 
    10 b = torch.Tensor([[5,6],[7,8]])
    11 print(b,b.shape)
    12 
    13 >>   tensor([[5., 6.],
    14             [7., 8.]]) torch.Size([2, 2])
    15 
    16 # 以第0维进行stack, 叠加的基本单位为序列本身, 即a与b, 因此输出[a, b], 输出维度为2×2×2
    17 d=torch.stack([a,b],0)
    18 print(d, d.size())
    19 >>  tensor([[[1., 2.],
    20          [3., 4.]],
    21 
    22         [[5., 6.],
    23          [7., 8.]]]) torch.Size([2, 2, 2])
    24 
    25 # 以第1维进行stack, 叠加的基本单位为每一行, 输出维度为2×2×2
    26 e=torch.stack([a,b],1)
    27 print(e, e.shape)
    28 
    29 >> tensor([[[1., 2.],
    30          [5., 6.]],
    31 
    32         [[3., 4.],
    33          [7., 8.]]]) torch.Size([2, 2, 2])
    34 
    35 # 以第2维进行stack, 叠加的基本单位为每一行的每一个元素, 输出维度为2×2×2
    36 f=torch.stack([a,b],2)
    37 print(f, f.shape)
    38 
    39 >> tensor([[[1., 5.],
    40          [2., 6.]],
    41 
    42         [[3., 7.],
    43          [4., 8.]]]) torch.Size([2, 2, 2])
    View Code

       分块则是与组合相反的操作, 指将Tensor分割成不同的子Tensor,主要有torch.chunk()torch.split()两个函数, 前者需要指定分块的数量,而后者则需要指定每一块的大小, 以整型或者list来表示。 具体示例如下 :

     1 import torch 
     2 
     3 a = torch.Tensor([[1,2,3], [4,5,6]])
     4 print(a, a.size())
     5 >> tensor([[1., 2., 3.],
     6         [4., 5., 6.]]) torch.Size([2, 3])
     7 
     8 # 使用chunk, 沿着第0维进行分块, 一共分两块, 因此分割成两个1×3的Tensor
     9 b = torch.chunk(a, 2, 0)
    10 print(b)
    11 >> (tensor([[1., 2., 3.]]), tensor([[4., 5., 6.]]))
    12 
    13 # 沿着第1维进行分块, 因此分割成两个Tensor, 当不能整除时, 最后一个的维数会小于前面的
    14 # 因此第一个Tensor为2×2, 第二个为2×1
    15 c = torch.chunk(a, 2, 1)
    16 print(c)
    17 >> (tensor([[1., 2.],
    18         [4., 5.]]), tensor([[3.],
    19         [6.]]))
    20 
    21 # 使用split, 沿着第0维分块, 每一块维度为2, 由于第一维维度总共为2, 因此相当于没有分割
    22 d = torch.split(a, 2, 0)
    23 print(d)
    24 >> (tensor([[1., 2., 3.],
    25         [4., 5., 6.]]),)
    26 
    27 # 沿着第1维分块, 每一块维度为2, 因此第一个Tensor为2×2, 第二个为2×1
    28 e = torch.split(a, 2, 1)
    29 print(e)
    30 >> (tensor([[1., 2.],
    31         [4., 5.]]), tensor([[3.],
    32         [6.]]))
    33  
    34 # split也可以根据输入的list进行自动分块, list中的元素代表了每一个块占的维度
    35 f = torch.split(a, [1,2], 1)
    36 print(f)
    37 >> (tensor([[1.],
    38         [4.]]), tensor([[2., 3.],
    39         [5., 6.]]))
    View Code


     

  • 相关阅读:
    CSU 1605 数独
    HDU 1426 dancing links解决数独问题
    FZU 1686 dlx重复覆盖
    hdu 2295 dlx重复覆盖+二分答案
    zju 3209 dancing links 求取最小行数
    hust 1017 dancing links 精确覆盖模板题
    POJ 1724 二维费用最短路
    【转载】学习总结:初等数论(3)——原根、指标及其应用
    【poj3415-Common Substrings】sam子串计数
    【hdu4436/LA6387-str2int】sam处理不同子串
  • 原文地址:https://www.cnblogs.com/zhaopengpeng/p/13597802.html
Copyright © 2011-2022 走看看