zoukankan      html  css  js  c++  java
  • Pytorch中的高级选择函数

      参考资料:

      https://tangshusen.me/Dive-into-DL-PyTorch/#/chapter02_prerequisite/2.2_tensor?id=_222-%e6%93%8d%e4%bd%9c

      https://pytorch.org/docs/stable/index.html

      深度学习里,很多时候我们只想取输出中的一部分值,此时便用上了Pytorch中的高级索引函数。我们见过最多的可能就是torch.gather这个函数了。这个随笔讲解一下Pytorch中的高级选择函数。

      一、torch.index_select

    torch.index_select(input, dim, index, *, out=None)-> Tensor
    """
    :param input(Tensor) - the input tensor
    :param dim(int) - the dimension in which we index
    :param index(IntTensor or LongTensor) - the 1-D tensor containing the indices to index
    :output out(Tensor,optional) - the output tensor
    """
    

      

    >>> x = torch.randn(3, 4)
    >>> x
    tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
            [-0.4664,  0.2647, -0.1228, -1.1068],
            [-1.1734, -0.6571,  0.7230, -0.6004]])
    >>> indices = torch.tensor([0, 2])
    >>> torch.index_select(x, 0, indices)
    tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
            [-1.1734, -0.6571,  0.7230, -0.6004]])
    >>> torch.index_select(x, 1, indices)
    tensor([[ 0.1427, -0.5414],
            [-0.4664, -0.1228],
            [-1.1734,  0.7230]])
    

      从这个官方示例可以看出,torch.index是针对某一个维度进行选择的。在示例代码中选择了dim=0(行)。然后使用一个可变长度的indices,选取indices对应的行数。

      二、torch.masked_select

      这个函数就更为简单粗暴了,从名字就可以看出来它是使用一个蒙版来选择Tensor中的值。也容易想到这个mask tensor需要和input tensor的形状保持一致。但是需要注意的是这个函数的输出是一个一维的Tensor,保存了从原始的Tensor中选择出来的所有值。

    >>> x = torch.randn(3, 4)
    >>> x
    tensor([[ 0.3552, -2.3825, -0.8297,  0.3477],
            [-1.2035,  1.2252,  0.5002,  0.6248],
            [ 0.1307, -2.0608,  0.1244,  2.0139]])
    >>> mask = x.ge(0.5)
    >>> mask
    tensor([[False, False, False, False],
            [False, True, True, True],
            [False, False, False, True]])
    >>> torch.masked_select(x, mask)
    tensor([ 1.2252,  0.5002,  0.6248,  2.0139])
    

      小知识点是这里使用了一个ge函数。该函数会逐元素比较array中的值和给定值的大小。然后返回布尔类型的tensor。总之要想用好masked_select,和各种能判断并生成布尔类型tensor的函数搭配起来才是正道。

      三、torch.gather

      这个函数有点绕,建议直接去看官方的文档,说明的比较清楚。

      https://pytorch.org/docs/stable/generated/torch.gather.html?highlight=gather#torch.gather

      

       我对于这个函数的理解是,首先有一个src_tensor和一个index_tensor,index_tensor和src_tensro有相同的维度数,但是在每一个维度上,应该有d(index_tensor) <= d(src_tensor)。

      首先看最简单的情况即index_tensor和src_tensor有完全相同的形状。按照文档上的说明,白话式的解释就是:对于输出矩阵的每一个位置,数据源还是来自src_tensor。只是在指定的那一个维度上,只取响应index的值。拿这个范例来说,[0, 0 ]这个位置,我们的数据源还是src_tensor,只不过数据源变广了,变成了src_tensor[0,:]即第一行中所有列的元素。然后再看index_tensor,这个位置是0。于是这个位置就被赋成饿了src_tensor[0, 0]。

      关于index_tensor比src_tensor小的情况。我最开始的理解就是使用了广播机制首先变换到一样的形状上,但是这是错误的,下面做一个和官方示例相似的小实验测试一下。

    >>>t = torch.tensor([[1, 2], [3, 4]])
    >>>torch.gather(t, 1, torch.tensor([[0, 0]]))
    tensor([[1, 1]])
    

      我们发现输出的结果其实是和index_tensor的形状一致的。也就是舍弃掉index_tensor没有覆盖的位置。

  • 相关阅读:
    vue箭头函数问题
    JS函数知识点梳理
    因tensorflow版本升级ImportError: No module named 'tensorflow.models.rnn'
    数据库优化,以实际SQL入手,带你一步一步走上SQL优化之路!
    在 IntelliJ IDEA 中这样使用 Git,效率提升2倍以上
    百万级高并发mongodb集群性能数十倍提升优化实践
    阿里巴巴Java开发手册正确学习姿势是怎样的?刷新代码规范认知
    50道Redis面试题史上最全,以后面试再也不怕问Redis了
    没想到Spring Boot居然这么耗内存,有点惊讶
    源码角度分析-newFixedThreadPool线程池导致的内存飙升问题
  • 原文地址:https://www.cnblogs.com/chester-cs/p/15438900.html
Copyright © 2011-2022 走看看