zoukankan      html  css  js  c++  java
  • pytorch tensor的索引与切片

    tensor索引与numpy类似,支持冒号,和数字直接索引

    import torch
    
    a = torch.Tensor(2, 3, 4)
    a
    # 输出:
          tensor([[[9.2755e-39, 1.0561e-38, 9.7347e-39, 1.1112e-38],
                 [1.0194e-38, 8.4490e-39, 1.0102e-38, 9.0919e-39],
                 [1.0102e-38, 8.9082e-39, 8.4489e-39, 1.0102e-38]],
        
                [[1.0561e-38, 1.0286e-38, 1.0653e-38, 1.0469e-38],
                 [9.5510e-39, 9.9184e-39, 9.0000e-39, 1.0561e-38],
                 [1.0653e-38, 4.1327e-39, 8.9082e-39, 9.8265e-39]]])
    
    # 冒号索引与数字索引
    a[:1, :2, 1]
    # 输出:
          tensor([[1.0561e-38, 8.4490e-39]])
    
    # 通过-1索引
    a[-1]
    # 输出:
          tensor([[1.0561e-38, 1.0286e-38, 1.0653e-38, 1.0469e-38],
                [9.5510e-39, 9.9184e-39, 9.0000e-39, 1.0561e-38],
                [1.0653e-38, 4.1327e-39, 8.9082e-39, 9.8265e-39]])
    

    ...(三个点)索引

    用于维度过多,且取中间多个维度所有数据的情况

    # 生成多维数据
    a = torch.rand(1,2,3,2,4,5)
    a
    # 输出:
         tensor([[[[[[0.1954, 0.1918, 0.3053, 0.3649, 0.3637],
                    [0.8467, 0.0205, 0.2187, 0.8438, 0.1754],
                    [0.7076, 0.7047, 0.1852, 0.5374, 0.7024],
                    [0.5630, 0.4526, 0.0662, 0.9463, 0.9294]],
        
                   [[0.6917, 0.5505, 0.5770, 0.3819, 0.9541],
                    [0.8957, 0.2530, 0.4858, 0.1866, 0.2542],
                    [0.3745, 0.2125, 0.5537, 0.5642, 0.2284],
                    [0.2634, 0.1147, 0.1793, 0.0277, 0.9800]]], 
    
                  ...
    
                  [[[0.9949, 0.2210, 0.3365, 0.0852, 0.4387],
                    [0.6440, 0.6391, 0.9141, 0.2288, 0.6203],
                    [0.0474, 0.7894, 0.4362, 0.9752, 0.7546],
                    [0.1234, 0.0246, 0.1436, 0.0053, 0.3405]],
        
                   [[0.8174, 0.9021, 0.0420, 0.2045, 0.2140],
                    [0.4844, 0.6342, 0.2965, 0.9299, 0.2284],
                    [0.1420, 0.1834, 0.0581, 0.8467, 0.8987],
                    [0.8012, 0.1526, 0.4293, 0.3928, 0.5437]]]]]]) 
    
    # 取第一维和最后一维的0索引数据,中间所有维度数据全部取出
    a[0, ..., 0]
    # 输出:
          tensor([[[[0.1954, 0.8467, 0.7076, 0.5630],
                  [0.6917, 0.8957, 0.3745, 0.2634]],
        
                 [[0.4374, 0.0534, 0.6809, 0.7086],
                  [0.2231, 0.6680, 0.8643, 0.9057]],
        
                 [[0.8169, 0.0649, 0.5923, 0.3802],
                  [0.2562, 0.0095, 0.8557, 0.6828]]],
        
        
                [[[0.1514, 0.3948, 0.6452, 0.6332],
                  [0.8872, 0.7304, 0.6853, 0.9814]],
        
                 [[0.5736, 0.5195, 0.9711, 0.5575],
                  [0.6778, 0.9334, 0.5647, 0.1006]],
        
                 [[0.9949, 0.6440, 0.0474, 0.1234],
                  [0.8174, 0.4844, 0.1420, 0.8012]]]])
    
    # 上面等价于
    a[0,:,:,:,:,0]
    # 输出:
          tensor([[[[0.1954, 0.8467, 0.7076, 0.5630],
                  [0.6917, 0.8957, 0.3745, 0.2634]],
        
                 [[0.4374, 0.0534, 0.6809, 0.7086],
                  [0.2231, 0.6680, 0.8643, 0.9057]],
        
                 [[0.8169, 0.0649, 0.5923, 0.3802],
                  [0.2562, 0.0095, 0.8557, 0.6828]]],
        
        
                [[[0.1514, 0.3948, 0.6452, 0.6332],
                  [0.8872, 0.7304, 0.6853, 0.9814]],
        
                 [[0.5736, 0.5195, 0.9711, 0.5575],
                  [0.6778, 0.9334, 0.5647, 0.1006]],
        
                 [[0.9949, 0.6440, 0.0474, 0.1234],
                  [0.8174, 0.4844, 0.1420, 0.8012]]]])
    可以看出,使用...可以节省操作。
    

    masked_select

    # 生成随机数据
    a = torch.randn(3, 4)
    a
    # 输出:
        tensor([[ 0.8710,  0.8862, -0.4620, -0.9985],
                [ 0.4734, -0.7182, -0.1516,  0.0209],
                [ 0.5089, -0.8130, -0.4519, -0.6190]])
    
    # 大于0.5的数据返回True
    mask = a.ge(0.5)
    mask
    # 输出:
        tensor([[ True,  True, False, False],
                [False, False, False, False],
                [ True, False, False, False]])
    
    # 通过上面生成的bool数据,利用masked_select来选择大于0.5的数据
    torch.masked_select(a, mask)
    # 输出:
        tensor([0.8710, 0.8862, 0.5089])  
    

    take

    a
    # 输出:
          tensor([[ 0.8710,  0.8862, -0.4620, -0.9985],
                [ 0.4734, -0.7182, -0.1516,  0.0209],
                [ 0.5089, -0.8130, -0.4519, -0.6190]])
    
    # 先将数据打平展开为一维,再选取展开后对应索引[0, 5, 8, 11]的数据
    torch.take(a, torch.tensor([0, 5, 8, 11]))
    # 输出:
          tensor([ 0.8710, -0.7182,  0.5089, -0.6190])
    
  • 相关阅读:
    nginx-1.8.1的安装
    ElasticSearch 在3节点集群的启动
    The type java.lang.CharSequence cannot be resolved. It is indirectly referenced from required .class files
    sqoop导入导出对mysql再带数据库test能跑通用户自己建立的数据库则不行
    LeetCode 501. Find Mode in Binary Search Tree (找到二叉搜索树的众数)
    LeetCode 437. Path Sum III (路径之和之三)
    LeetCode 404. Sum of Left Leaves (左子叶之和)
    LeetCode 257. Binary Tree Paths (二叉树路径)
    LeetCode Questions List (LeetCode 问题列表)- Java Solutions
    LeetCode 561. Array Partition I (数组分隔之一)
  • 原文地址:https://www.cnblogs.com/jaysonteng/p/13019514.html
Copyright © 2011-2022 走看看