  • PyTorch: joining and splitting tensors with cat, stack, split, and chunk

    1. Concatenation with cat

    • Function: joins tensors directly along the dimension given by dim
    • dim defaults to 0
    • The sizes may differ along the specified dim, but must match along every other dim, otherwise an error is raised.

    1) Concatenating two tensors of the same shape

    import torch

    a = torch.rand(2, 3, 2)
    a
    # Output:
        tensor([[[0.6072, 0.6531],
                 [0.2023, 0.2506],
                 [0.0590, 0.3390]],
    
                [[0.3994, 0.0110],
                 [0.3615, 0.3826],
                 [0.3033, 0.3096]]])
        
    b = torch.rand(2, 3, 2)  # b has the same shape as a
    b
    # Output:
        tensor([[[0.6144, 0.4561],
                 [0.9263, 0.0644],
                 [0.2838, 0.3456]],
    
                [[0.1126, 0.5303],
                 [0.8140, 0.5715],
                 [0.7627, 0.5095]]])   
    
        
    # dim selects the dimension to join along
    torch.cat([a, b])  # dim defaults to 0 when not specified
    # Output:
        tensor([[[0.6072, 0.6531],
                 [0.2023, 0.2506],
                 [0.0590, 0.3390]],
    
                [[0.3994, 0.0110],
                 [0.3615, 0.3826],
                 [0.3033, 0.3096]],
    
                [[0.6144, 0.4561],
                 [0.9263, 0.0644],
                 [0.2838, 0.3456]],
    
                [[0.1126, 0.5303],
                 [0.8140, 0.5715],
                 [0.7627, 0.5095]]])
    
    # join along dim=0
    torch.cat([a, b], dim=0)  # dim=0 gives the same result as the default above
    # Output:
        tensor([[[0.6072, 0.6531],
                 [0.2023, 0.2506],
                 [0.0590, 0.3390]],
    
                [[0.3994, 0.0110],
                 [0.3615, 0.3826],
                 [0.3033, 0.3096]],
    
                [[0.6144, 0.4561],
                 [0.9263, 0.0644],
                 [0.2838, 0.3456]],
    
                [[0.1126, 0.5303],
                 [0.8140, 0.5715],
                 [0.7627, 0.5095]]])
        
    # join along dim=1
    torch.cat([a, b], dim=1)
    # Output:
        tensor([[[0.6072, 0.6531],
                 [0.2023, 0.2506],
                 [0.0590, 0.3390],
                 [0.6144, 0.4561],
                 [0.9263, 0.0644],
                 [0.2838, 0.3456]],
    
                [[0.3994, 0.0110],
                 [0.3615, 0.3826],
                 [0.3033, 0.3096],
                 [0.1126, 0.5303],
                 [0.8140, 0.5715],
                 [0.7627, 0.5095]]])    
        
    # join along dim=2
    torch.cat([a, b], dim=2)
    # Output:
        tensor([[[0.6072, 0.6531, 0.6144, 0.4561],
                 [0.2023, 0.2506, 0.9263, 0.0644],
                 [0.0590, 0.3390, 0.2838, 0.3456]],
    
                [[0.3994, 0.0110, 0.1126, 0.5303],
                 [0.3615, 0.3826, 0.8140, 0.5715],
                 [0.3033, 0.3096, 0.7627, 0.5095]]])    
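The three results above follow one shape rule: cat along dim d sums the sizes along d and leaves every other dimension unchanged. A minimal sketch:

```python
import torch

a = torch.rand(2, 3, 2)
b = torch.rand(2, 3, 2)

# cat along dim d adds up the sizes along d; the other dims are untouched
print(torch.cat([a, b], dim=0).shape)  # torch.Size([4, 3, 2])
print(torch.cat([a, b], dim=1).shape)  # torch.Size([2, 6, 2])
print(torch.cat([a, b], dim=2).shape)  # torch.Size([2, 3, 4])
```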
    

    2) Concatenating two tensors of different shapes

    Compare this with the same-shape examples above; the contrast makes the rule easier to see.

    a = torch.rand(2, 3, 2)
    a
    # Output:
        tensor([[[0.6447, 0.9758],
                 [0.0688, 0.9082],
                 [0.0083, 0.0109]],
    
                [[0.5239, 0.1217],
                 [0.9562, 0.6831],
                 [0.8691, 0.2769]]])
        
    b = torch.rand(2, 2, 2)
    b
    # Output:
    tensor([[[0.3604, 0.7585],
             [0.7831, 0.0439]],
    
            [[0.2040, 0.5002],
             [0.8878, 0.5973]]])
    
    # dim not specified:
    torch.cat([a, b])
    # dim defaults to 0, and a and b differ in size along dim 1 (3 for a, 2 for b), so this raises an error
    # Output:
        ---------------------------------------------------------------------------
        RuntimeError                              Traceback (most recent call last)
        <ipython-input-32-5484713fecdf> in <module>
        ----> 1 torch.cat([a, b])
          2 # Output:
    
        RuntimeError: inv
            
    # a and b differ only in dim 1, so cat can join them only along dim=1
    torch.cat([a, b], dim=1)
    # Output:
        tensor([[[0.6447, 0.9758],
                 [0.0688, 0.9082],
                 [0.0083, 0.0109],
                 [0.3604, 0.7585],
                 [0.7831, 0.0439]],
    
                [[0.5239, 0.1217],
                 [0.9562, 0.6831],
                 [0.8691, 0.2769],
                 [0.2040, 0.5002],
                 [0.8878, 0.5973]]])        
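Before calling cat, it can help to check compatibility explicitly. The helper below is a sketch (`can_cat` is a hypothetical name, not part of PyTorch): it returns True when two tensors match in every dimension except the one being joined.

```python
import torch

def can_cat(x, y, dim=0):
    """Hypothetical helper: True if x and y can be concatenated along `dim`,
    i.e. they have the same number of dims and match in every other dim."""
    if x.dim() != y.dim():
        return False
    return all(x.size(d) == y.size(d) for d in range(x.dim()) if d != dim)

a = torch.rand(2, 3, 2)
b = torch.rand(2, 2, 2)
print(can_cat(a, b, dim=0))  # False: sizes differ in dim 1
print(can_cat(a, b, dim=1))  # True: only dim 1 differs
```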
    

    2. Stacking with stack

    • Unlike cat, stack inserts a new dimension at the given dim and then joins the tensors along it.
    • Think of it as wrapping each of the two inputs in an extra pair of [] at the specified dim before joining.
    • Compare with cat: cat keeps the two inputs inside a single [] along the joined dim, while stack leaves them in two separate [].
    • The two tensors passed to stack must have exactly the same shape.
    • dim defaults to 0
    a = torch.rand(2, 5)
    a
    # Output:
        tensor([[0.2214, 0.2666, 0.6486, 0.7050, 0.4259],
                [0.6929, 0.4945, 0.0631, 0.4546, 0.6918]])
    
    b = torch.rand(2, 5)
    b
    # Output:
        tensor([[0.7893, 0.4141, 0.2971, 0.6791, 0.9791],
                [0.4722, 0.7540, 0.5282, 0.0625, 0.0448]])
        
    # dim defaults to 0: the two tensors are stacked directly
    torch.stack([a, b])
    # Output:
        tensor([[[0.2214, 0.2666, 0.6486, 0.7050, 0.4259],
                 [0.6929, 0.4945, 0.0631, 0.4546, 0.6918]],
    
                [[0.7893, 0.4141, 0.2971, 0.6791, 0.9791],
                 [0.4722, 0.7540, 0.5282, 0.0625, 0.0448]]])    
        
    # dim=0 explicitly
    # Compare with cat at dim=0: there the joined data sits in one []; here it is kept as two segments (two [])
    torch.stack([a, b], dim=0)  # same result as the default above
    # Output:
        tensor([[[0.2214, 0.2666, 0.6486, 0.7050, 0.4259],
                 [0.6929, 0.4945, 0.0631, 0.4546, 0.6918]],
    
                [[0.7893, 0.4141, 0.2971, 0.6791, 0.9791],
                 [0.4722, 0.7540, 0.5282, 0.0625, 0.0448]]])    
        
    # dim=1: stack along dimension 1.
    # Note: the result differs from dim=0 above.
    torch.stack([a, b], dim=1)
    # Output:
        tensor([[[0.2214, 0.2666, 0.6486, 0.7050, 0.4259],
                 [0.7893, 0.4141, 0.2971, 0.6791, 0.9791]],
    
                [[0.6929, 0.4945, 0.0631, 0.4546, 0.6918],
                 [0.4722, 0.7540, 0.5282, 0.0625, 0.0448]]])    
        
    # dim=2: stack along dimension 2.
    torch.stack([a, b], dim=2)
    # Output:
        tensor([[[0.2214, 0.7893],
                 [0.2666, 0.4141],
                 [0.6486, 0.2971],
                 [0.7050, 0.6791],
                 [0.4259, 0.9791]],
    
                [[0.6929, 0.4722],
                 [0.4945, 0.7540],
                 [0.0631, 0.5282],
                 [0.4546, 0.0625],
                 [0.6918, 0.0448]]])    
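One way to picture the extra dimension stack inserts: stacking at dim d gives the same result as unsqueezing each input at d and then cat-ing along d. A small check:

```python
import torch

a = torch.rand(2, 5)
b = torch.rand(2, 5)

# stack at dim d behaves like unsqueeze(d) on each input followed by cat along d
for d in range(3):
    stacked = torch.stack([a, b], dim=d)
    catted = torch.cat([a.unsqueeze(d), b.unsqueeze(d)], dim=d)
    assert torch.equal(stacked, catted)

print(torch.stack([a, b], dim=1).shape)  # torch.Size([2, 2, 5])
```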
    

    3. Splitting with split

    • Specify the dim to split along
    • Give the sizes of the pieces after the split
    a = torch.rand(4, 3, 2)
    a
    # Output:
        tensor([[[0.5790, 0.6024],
                 [0.4730, 0.0734],
                 [0.2274, 0.7212]],
    
                [[0.7051, 0.1568],
                 [0.5890, 0.1075],
                 [0.7469, 0.0659]],
    
                [[0.7780, 0.5424],
                 [0.4344, 0.8551],
                 [0.6729, 0.7372]],
    
                [[0.1669, 0.8596],
                 [0.9490, 0.8378],
                 [0.7889, 0.2192]]])
    
        
    # dim defaults to 0
    # dim 0 has size 4, so any split sizes that sum to 4 work: 2 + 2, 1 + 3, etc.
    a.split([2, 2])
    # Output:
        (tensor([[[0.5790, 0.6024],
                  [0.4730, 0.0734],
                  [0.2274, 0.7212]],
    
                 [[0.7051, 0.1568],
                  [0.5890, 0.1075],
                  [0.7469, 0.0659]]]),
         tensor([[[0.7780, 0.5424],
                  [0.4344, 0.8551],
                  [0.6729, 0.7372]],
    
                 [[0.1669, 0.8596],
                  [0.9490, 0.8378],
                  [0.7889, 0.2192]]]))
    
        
    # dim 1 has size 3, so split it as 2 + 1 = 3
    a.split([2, 1], dim=1)
    # Output:
        (tensor([[[0.5790, 0.6024],
                  [0.4730, 0.0734]],
    
                 [[0.7051, 0.1568],
                  [0.5890, 0.1075]],
    
                 [[0.7780, 0.5424],
                  [0.4344, 0.8551]],
    
                 [[0.1669, 0.8596],
                  [0.9490, 0.8378]]]),
         tensor([[[0.2274, 0.7212]],
    
                 [[0.7469, 0.0659]],
    
                 [[0.6729, 0.7372]],
    
                 [[0.7889, 0.2192]]]))    
    
    
    # dim 2 has size 2, so split it as 1 + 1 = 2
    a.split([1, 1], dim=2)
    # Output:
    (tensor([[[0.5790],
              [0.4730],
              [0.2274]],
     
             [[0.7051],
              [0.5890],
              [0.7469]],
     
             [[0.7780],
              [0.4344],
              [0.6729]],
     
             [[0.1669],
              [0.9490],
              [0.7889]]]),
     tensor([[[0.6024],
              [0.0734],
              [0.7212]],
     
             [[0.1568],
              [0.1075],
              [0.0659]],
     
             [[0.5424],
              [0.8551],
              [0.7372]],
     
             [[0.8596],
              [0.8378],
              [0.2192]]]))    
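Besides a list of sizes, split also accepts a single int, which is then the size of each piece along the chosen dim; the last piece is smaller when the size is not evenly divisible. A quick sketch:

```python
import torch

a = torch.rand(4, 3, 2)

# split with an int gives pieces of that SIZE along dim
print([p.size(0) for p in a.split(2, dim=0)])  # [2, 2]  (4 = 2 + 2)
print([p.size(1) for p in a.split(2, dim=1)])  # [2, 1]  (3 = 2 + 1, last piece smaller)
```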
    

    4. Splitting with chunk

    • chunk takes, for the given dim, the number of pieces to split into
    • If the size is not evenly divisible by that number, each piece gets size ceil(size / pieces); the last piece may be smaller, and you may even get back fewer pieces than requested
    • dim defaults to 0
    a
    # Output:
        tensor([[[0.5790, 0.6024],
                 [0.4730, 0.0734],
                 [0.2274, 0.7212]],
    
                [[0.7051, 0.1568],
                 [0.5890, 0.1075],
                 [0.7469, 0.0659]],
    
                [[0.7780, 0.5424],
                 [0.4344, 0.8551],
                 [0.6729, 0.7372]],
    
                [[0.1669, 0.8596],
                 [0.9490, 0.8378],
                 [0.7889, 0.2192]]])
        
    # dim defaults to 0
    # split the data into 4 equal pieces along dim 0
    a.chunk(4)
    # Output:
        (tensor([[[0.5790, 0.6024],
                  [0.4730, 0.0734],
                  [0.2274, 0.7212]]]),
         tensor([[[0.7051, 0.1568],
                  [0.5890, 0.1075],
                  [0.7469, 0.0659]]]),
         tensor([[[0.7780, 0.5424],
                  [0.4344, 0.8551],
                  [0.6729, 0.7372]]]),
         tensor([[[0.1669, 0.8596],
                  [0.9490, 0.8378],
                  [0.7889, 0.2192]]]))    
        
    # ask for 3 pieces along dim 0
    # each piece gets size ceil(4 / 3) = 2, so only 2 pieces come back even though 3 were requested
    a.chunk(3, dim=0)
    # Output:
        (tensor([[[0.5790, 0.6024],
                  [0.4730, 0.0734],
                  [0.2274, 0.7212]],
    
                 [[0.7051, 0.1568],
                  [0.5890, 0.1075],
                  [0.7469, 0.0659]]]),
         tensor([[[0.7780, 0.5424],
                  [0.4344, 0.8551],
                  [0.6729, 0.7372]],
    
                 [[0.1669, 0.8596],
                  [0.9490, 0.8378],
                  [0.7889, 0.2192]]]))    
        
    # split into 3 pieces along dim 1
    a.chunk(3, dim=1)
    # Output:
        (tensor([[[0.5790, 0.6024]],
    
                 [[0.7051, 0.1568]],
    
                 [[0.7780, 0.5424]],
    
                 [[0.1669, 0.8596]]]),
         tensor([[[0.4730, 0.0734]],
    
                 [[0.5890, 0.1075]],
    
                 [[0.4344, 0.8551]],
    
                 [[0.9490, 0.8378]]]),
         tensor([[[0.2274, 0.7212]],
    
                 [[0.7469, 0.0659]],
    
                 [[0.6729, 0.7372]],
    
                 [[0.7889, 0.2192]]]))    
        
    # split into 2 pieces along dim 2
    a.chunk(2, dim=2)
    # Output:
        (tensor([[[0.5790],
                  [0.4730],
                  [0.2274]],
    
                 [[0.7051],
                  [0.5890],
                  [0.7469]],
    
                 [[0.7780],
                  [0.4344],
                  [0.6729]],
    
                 [[0.1669],
                  [0.9490],
                  [0.7889]]]),
         tensor([[[0.6024],
                  [0.0734],
                  [0.7212]],
    
                 [[0.1568],
                  [0.1075],
                  [0.0659]],
    
                 [[0.5424],
                  [0.8551],
                  [0.7372]],
    
                 [[0.8596],
                  [0.8378],
                  [0.2192]]]))    
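An easy way to keep chunk and split apart: chunk takes the number of pieces, while split (with an int) takes the size of each piece. A side-by-side sketch:

```python
import torch

a = torch.rand(4, 3, 2)

# chunk(n): n is the NUMBER of pieces -> piece size ceil(4 / 3) = 2
print([t.size(0) for t in a.chunk(3, dim=0)])  # [2, 2]
# split(n): n is the SIZE of each piece -> pieces of size 3, then the remainder
print([t.size(0) for t in a.split(3, dim=0)])  # [3, 1]
```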
    
  • Original post: https://www.cnblogs.com/jaysonteng/p/13038080.html