zoukankan      html  css  js  c++  java
  • torch.cat()和torch.stack()

    torch.cat() 和 torch.stack()略有不同
    torch.cat(tensors,dim=0,out=None)→ Tensor
    torch.cat()对tensors沿指定维度拼接,但返回的Tensor的维数不会变,可理解为续接;
    torch.stack(tensors,dim=0,out=None)→ Tensor
    torch.stack()同样是对tensors沿指定维度拼接,但返回的Tensor会多一维,可理解为叠加;
    ————————————————
    版权声明:本文为CSDN博主「进阶媛小吴」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
    原文链接:https://blog.csdn.net/wuli_xin/article/details/118972316

     

    结果为:

     上述行数相同d,c,在第一维度也即列上拼接时,能拼接成100行六列的tensor.

    import torch
    # t1=torch.tensor([1,1,1])
    # t2=torch.tensor([2,2,2])
    # t3=torch.tensor([3,3,3])
    
    f1=torch.tensor([[1,2,3],[4,5,6]])
    f2=torch.tensor([[7,8,9],[10,11,12]])
    c=torch.tensor([[13,14,15],[16,17,18]])
    
    # a=torch.cat((f1,f2,c),dim=0)
    # b=torch.cat((f1,f2,c),1)
    # print(a.shape,b.shape,sep='
    ')
    # d=torch.rand(100,4)
    # e=torch.cat((d,c),1)
    # print(e.shape)
    g=torch.stack((f1,f2,c),0)
    g1=torch.stack((f1,f2,c),1)
    print(f1.shape,f2.shape,c.shape)
    print('g: ',g.shape,g,sep='
    ')
    print('g1: ',g1.shape,g1,sep='
    ')
    
    输出结果为:
    
    torch.Size([2, 3]) torch.Size([2, 3]) torch.Size([2, 3])
    g: 
    torch.Size([3, 2, 3])#本来3个,就3个;本来2行3列就两行三列;只不过把他们放到一起,变成了3维的,多了一个维度;个人理解,可能有误。
    tensor([[[ 1,  2,  3],
             [ 4,  5,  6]],
    
            [[ 7,  8,  9],
             [10, 11, 12]],
    
            [[13, 14, 15],
             [16, 17, 18]]])
    g1: 
    torch.Size([2, 3, 3])#把本来的三个中,每个的第一列拼在一块;第二列拼在一块;再把拼过后的第一列和第二列分别作为一个二维矩阵; 个人理解,可能有误。
    
    tensor([[[ 1,  2,  3],
             [ 7,  8,  9],
             [13, 14, 15]],
    
            [[ 4,  5,  6],
             [10, 11, 12],
             [16, 17, 18]]])
    
    import torch
    # t1=torch.tensor([1,1,1])
    # t2=torch.tensor([2,2,2])
    # t3=torch.tensor([3,3,3])
    
    f1=torch.tensor([[1,2,3],[4,5,6]])
    f2=torch.tensor([[7,8,9],[10,11,12]])
    c=torch.tensor([[13,14,15],[16,17,18]])
    
    # a=torch.cat((f1,f2,c),dim=0)
    # b=torch.cat((f1,f2,c),1)
    # print(a.shape,b.shape,sep='
    ')
    # d=torch.rand(100,4)
    # e=torch.cat((d,c),1)
    # print(e.shape)
    g=torch.stack((f1,f2,c),0)
    g1=torch.stack((f1,f2,c),2)
    print(f1.shape,f2.shape,c.shape)
    print('g: ',g.shape,g,sep='
    ')
    print('g1: ',g1.shape,g1,sep='
    ')
    
    输出结果:
    torch.Size([2, 3]) torch.Size([2, 3]) torch.Size([2, 3])
    g: 
    torch.Size([3, 2, 3])
    tensor([[[ 1,  2,  3],
             [ 4,  5,  6]],
            [[ 7,  8,  9],
             [10, 11, 12]],
            [[13, 14, 15],
             [16, 17, 18]]])
    g1: 
    torch.Size([2, 3, 3])
    tensor([[[ 1,  7, 13],
             [ 2,  8, 14],
             [ 3,  9, 15]],
            [[ 4, 10, 16],
             [ 5, 11, 17],
             [ 6, 12, 18]]])
    

      

    import torch
    # t1=torch.tensor([1,1,1])
    # t2=torch.tensor([2,2,2])
    # t3=torch.tensor([3,3,3])
    
    f1=torch.tensor([[1,2,3],[4,5,6]])
    f2=torch.tensor([[7,8,9],[10,11,12]])
    c=torch.tensor([[13,14,15],[16,17,18]])
    
    # a=torch.cat((f1,f2,c),dim=0)
    # b=torch.cat((f1,f2,c),1)
    # print(a.shape,b.shape,sep='
    ')
    # d=torch.rand(100,4)
    # e=torch.cat((d,c),1)
    # print(e.shape)
    g=torch.stack((f1,f2,c),0)
    g1=torch.stack((f1,f2,c),3)#此处dim=3,或比3大的任何正数,都是如下报错结果。
    print(f1.shape,f2.shape,c.shape)
    print('g: ',g.shape,g,sep='
    ')
    print('g1: ',g1.shape,g1,sep='
    ')
    
    输出结果:
    Traceback (most recent call last):
      File "<input>", line 17, in <module>
    IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
    

    此外: 

    如果,torch.stack()的维度dim输入的是-1,-2,-3,也都可以正确输出结果。但是如果输入比-3小的任何数则会报错;具体如下:

    import torch
    # t1=torch.tensor([1,1,1])
    # t2=torch.tensor([2,2,2])
    # t3=torch.tensor([3,3,3])
    
    f1=torch.tensor([[1,2,3],[4,5,6]])
    f2=torch.tensor([[7,8,9],[10,11,12]])
    c=torch.tensor([[13,14,15],[16,17,18]])
    
    # a=torch.cat((f1,f2,c),dim=0)
    # b=torch.cat((f1,f2,c),1)
    # print(a.shape,b.shape,sep='
    ')
    # d=torch.rand(100,4)
    # e=torch.cat((d,c),1)
    # print(e.shape)
    g=torch.stack((f1,f2,c),0)
    g1=torch.stack((f1,f2,c),-1) #此时,维度是-1
    print(f1.shape,f2.shape,c.shape)
    print('g: ',g.shape,g,sep='
    ')
    print('g1: ',g1.shape,g1,sep='
    ')
    
    输出结果为:
    torch.Size([2, 3]) torch.Size([2, 3]) torch.Size([2, 3])
    g: 
    torch.Size([3, 2, 3])
    tensor([[[ 1,  2,  3],
             [ 4,  5,  6]],
    
            [[ 7,  8,  9],
             [10, 11, 12]],
    
            [[13, 14, 15],
             [16, 17, 18]]])
    g1: 
    torch.Size([2, 3, 3])
    tensor([[[ 1,  7, 13],
             [ 2,  8, 14],
             [ 3,  9, 15]],
    
            [[ 4, 10, 16],
             [ 5, 11, 17],
             [ 6, 12, 18]]])
    

    torch.stack()的维度dim输入的是--2;

    输出结果为:
    torch.Size([2, 3]) torch.Size([2, 3]) torch.Size([2, 3])
    g: 
    torch.Size([3, 2, 3])
    tensor([[[ 1,  2,  3],
             [ 4,  5,  6]],
    
            [[ 7,  8,  9],
             [10, 11, 12]],
    
            [[13, 14, 15],
             [16, 17, 18]]])
    g1: 
    torch.Size([2, 3, 3])
    tensor([[[ 1,  2,  3],
             [ 7,  8,  9],
             [13, 14, 15]],
    
            [[ 4,  5,  6],
             [10, 11, 12],
             [16, 17, 18]]])
    

    torch.stack()的维度dim输入的是-3;

    torch.Size([2, 3]) torch.Size([2, 3]) torch.Size([2, 3])
    g: 
    torch.Size([3, 2, 3])
    tensor([[[ 1,  2,  3],
             [ 4,  5,  6]],
    
            [[ 7,  8,  9],
             [10, 11, 12]],
    
            [[13, 14, 15],
             [16, 17, 18]]])
    g1: 
    torch.Size([3, 2, 3])
    tensor([[[ 1,  2,  3],
             [ 4,  5,  6]],
    
            [[ 7,  8,  9],
             [10, 11, 12]],
    
            [[13, 14, 15],
             [16, 17, 18]]])
    
    import torch
    # t1=torch.tensor([1,1,1])
    # t2=torch.tensor([2,2,2])
    # t3=torch.tensor([3,3,3])
    
    f1=torch.tensor([[1,2,3],[4,5,6]])
    f2=torch.tensor([[7,8,9],[10,11,12]])
    c=torch.tensor([[13,14,15],[16,17,18]])
    
    # a=torch.cat((f1,f2,c),dim=0)
    # b=torch.cat((f1,f2,c),1)
    # print(a.shape,b.shape,sep='
    ')
    # d=torch.rand(100,4)
    # e=torch.cat((d,c),1)
    # print(e.shape)
    g=torch.stack((f1,f2,c),0)
    g1=torch.stack((f1,f2,c),-4)#此处是dim=-4,小于-4的任何负数,输出类似的结果。
    print(f1.shape,f2.shape,c.shape)
    print('g: ',g.shape,g,sep='
    ')
    print('g1: ',g1.shape,g1,sep='
    ')
    
    输出结果为:
    Traceback (most recent call last):
      File "<input>", line 17, in <module>
    IndexError: Dimension out of range (expected to be in range of [-3, 2], but got -4)
    

      

  • 相关阅读:
    常用品牌交换机镜像抓包配置
    BGP知识点备忘录
    IS-IS路由协议地址详解
    Linux msmtp+mutt发邮件
    Linux添加一临时用户拥有root权限最快方式
    ELK5.0全程普通用户源码安装指南(CentOS6.5)
    改变文件的拥有者和改变文件的拥有组
    Linux chmod命令详解
    Linux目录介绍
    php时间戳转化成时间相差8小时问题
  • 原文地址:https://www.cnblogs.com/Li-JT/p/15165425.html
Copyright © 2011-2022 走看看