zoukankan      html  css  js  c++  java
  • pytorch tensor的索引与切片

    tensor索引与numpy类似,支持冒号,和数字直接索引

    import torch
    
    a = torch.Tensor(2, 3, 4)
    a
    # 输出:
          tensor([[[9.2755e-39, 1.0561e-38, 9.7347e-39, 1.1112e-38],
                 [1.0194e-38, 8.4490e-39, 1.0102e-38, 9.0919e-39],
                 [1.0102e-38, 8.9082e-39, 8.4489e-39, 1.0102e-38]],
        
                [[1.0561e-38, 1.0286e-38, 1.0653e-38, 1.0469e-38],
                 [9.5510e-39, 9.9184e-39, 9.0000e-39, 1.0561e-38],
                 [1.0653e-38, 4.1327e-39, 8.9082e-39, 9.8265e-39]]])
    
    # 冒号索引与数字索引
    a[:1, :2, 1]
    # 输出:
          tensor([[1.0561e-38, 8.4490e-39]])
    
    # 通过-1索引
    a[-1]
    # 输出:
          tensor([[1.0561e-38, 1.0286e-38, 1.0653e-38, 1.0469e-38],
                [9.5510e-39, 9.9184e-39, 9.0000e-39, 1.0561e-38],
                [1.0653e-38, 4.1327e-39, 8.9082e-39, 9.8265e-39]])
    

    ...(三个点)索引

    用于维度过多,且取中间多个维度所有数据的情况

    # 生成多维数据
    a = torch.rand(1,2,3,2,4,5)
    a
    # 输出:
         tensor([[[[[[0.1954, 0.1918, 0.3053, 0.3649, 0.3637],
                    [0.8467, 0.0205, 0.2187, 0.8438, 0.1754],
                    [0.7076, 0.7047, 0.1852, 0.5374, 0.7024],
                    [0.5630, 0.4526, 0.0662, 0.9463, 0.9294]],
        
                   [[0.6917, 0.5505, 0.5770, 0.3819, 0.9541],
                    [0.8957, 0.2530, 0.4858, 0.1866, 0.2542],
                    [0.3745, 0.2125, 0.5537, 0.5642, 0.2284],
                    [0.2634, 0.1147, 0.1793, 0.0277, 0.9800]]], 
    
                  ...
    
                  [[[0.9949, 0.2210, 0.3365, 0.0852, 0.4387],
                    [0.6440, 0.6391, 0.9141, 0.2288, 0.6203],
                    [0.0474, 0.7894, 0.4362, 0.9752, 0.7546],
                    [0.1234, 0.0246, 0.1436, 0.0053, 0.3405]],
        
                   [[0.8174, 0.9021, 0.0420, 0.2045, 0.2140],
                    [0.4844, 0.6342, 0.2965, 0.9299, 0.2284],
                    [0.1420, 0.1834, 0.0581, 0.8467, 0.8987],
                    [0.8012, 0.1526, 0.4293, 0.3928, 0.5437]]]]]]) 
    
    # 取第一维和最后一维的0索引数据,中间所有维度数据全部取出
    a[0, ..., 0]
    # 输出:
          tensor([[[[0.1954, 0.8467, 0.7076, 0.5630],
                  [0.6917, 0.8957, 0.3745, 0.2634]],
        
                 [[0.4374, 0.0534, 0.6809, 0.7086],
                  [0.2231, 0.6680, 0.8643, 0.9057]],
        
                 [[0.8169, 0.0649, 0.5923, 0.3802],
                  [0.2562, 0.0095, 0.8557, 0.6828]]],
        
        
                [[[0.1514, 0.3948, 0.6452, 0.6332],
                  [0.8872, 0.7304, 0.6853, 0.9814]],
        
                 [[0.5736, 0.5195, 0.9711, 0.5575],
                  [0.6778, 0.9334, 0.5647, 0.1006]],
        
                 [[0.9949, 0.6440, 0.0474, 0.1234],
                  [0.8174, 0.4844, 0.1420, 0.8012]]]])
    
    # 上面等价于
    a[0,:,:,:,:,0]
    # 输出:
          tensor([[[[0.1954, 0.8467, 0.7076, 0.5630],
                  [0.6917, 0.8957, 0.3745, 0.2634]],
        
                 [[0.4374, 0.0534, 0.6809, 0.7086],
                  [0.2231, 0.6680, 0.8643, 0.9057]],
        
                 [[0.8169, 0.0649, 0.5923, 0.3802],
                  [0.2562, 0.0095, 0.8557, 0.6828]]],
        
        
                [[[0.1514, 0.3948, 0.6452, 0.6332],
                  [0.8872, 0.7304, 0.6853, 0.9814]],
        
                 [[0.5736, 0.5195, 0.9711, 0.5575],
                  [0.6778, 0.9334, 0.5647, 0.1006]],
        
                 [[0.9949, 0.6440, 0.0474, 0.1234],
                  [0.8174, 0.4844, 0.1420, 0.8012]]]])
    可以看出,使用...可以节省操作。
    

    masked_select

    # 生成随机数据
    a = torch.randn(3, 4)
    a
    # 输出:
        tensor([[ 0.8710,  0.8862, -0.4620, -0.9985],
                [ 0.4734, -0.7182, -0.1516,  0.0209],
                [ 0.5089, -0.8130, -0.4519, -0.6190]])
    
    # 大于0.5的数据返回True
    mask = a.ge(0.5)
    mask
    # 输出:
        tensor([[ True,  True, False, False],
                [False, False, False, False],
                [ True, False, False, False]])
    
    # 通过上面生成的bool数据,利用masked_select来选择大于0.5的数据
    torch.masked_select(a, mask)
    # 输出:
        tensor([0.8710, 0.8862, 0.5089])  
    

    take

    a
    # 输出:
          tensor([[ 0.8710,  0.8862, -0.4620, -0.9985],
                [ 0.4734, -0.7182, -0.1516,  0.0209],
                [ 0.5089, -0.8130, -0.4519, -0.6190]])
    
    # 先将数据打平展开为一维,再选取展开后对应索引[0, 5, 8, 11]的数据
    torch.take(a, torch.tensor([0, 5, 8, 11]))
    # 输出:
          tensor([ 0.8710, -0.7182,  0.5089, -0.6190])
    
  • 相关阅读:
    Allegro绘制PCB流程
    KSImageNamed-Xcode
    UIApplicationsharedApplication的常用使用方法
    javascript中间AJAX
    hdu1845 Jimmy’s Assignment --- 完整匹配
    嵌入式控制系统和计算机系统
    Bean行为破坏之前,
    jsonkit 分解nsarray 时刻 一个错误
    IO 字符流学习
    2013级别C++文章9周(春天的)工程——运算符重载(两)
  • 原文地址:https://www.cnblogs.com/jaysonteng/p/13019514.html
Copyright © 2011-2022 走看看