zoukankan      html  css  js  c++  java
  • PyTorch tensor的scatter_函数

    TORCH.TENSOR.SCATTER_

    Tensor.scatter_(dim, index, src, reduce=None) → Tensor

    把src里面的元素按照index和dim参数给出的条件,放置到目标tensor里面,在这里是self。下面为了讨论方便,目标tensor和self在交换使用的时候,请大家知道,在这里指的是同一个tensor.

    注意:这里self, index, src三个张量的纬度必须是一致的(但每个纬度上的size不一定一致,请大家体会)。
    只有src是个例外,可以是标量,即单个数字。
    这个时候,就是把这单个数字,根据参数的条件, 放置到self的不同位置。
    

    那么怎么放呢?根据PyTorch的文档,对于一个3-D的tensor,放置方法如下:

    self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
    self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
    self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2
    

    由上面的公式很容易推断出,对于一个2-D的tensor,放置方法如下:

    self[index[i][j]][j] = src[i][j]  # if dim == 0
    self[i][index[i][j]] = src[i][j]  # if dim == 1
    

    对于一个1-D的tensor,放置方法如下:

    self[index[i]] = src[i]  # if dim == 0	
    

    是不是有点晕?我们来解释一下。

    1. 当dim为0的时候

    我们把src里面的元素放置到self里面的时候,假设是放置src的第[i][j][k]个元素,那么放置到self里面的位置(三个纬度的值)分别如下:

    • index[i][j][k]
    • j
    • k

    对于第一个纬度的位置,就是把i,替换为index[i][j][k]

    那么这里有个问题,如果index的size比src的size要小的话,怎么办? 那就是对于在index里面,找不到的值,就不再处理,self里面原来是什么还是什么。


    为了更加方便说明,这里假设src是1-D的,即一个1维数组,那么dim只有一个值可以设置,即0(当然也可以说是有两个值,-1也是可以的,但是-1和0实际上指的都是第一个纬度)。那么这个时候self和index按照上面的规则,也必须都是一维的(参见上面的注意)。那么我们直接来看一段示例代码和输出来进行解释:

    a = torch.arange(1, 6).long()
    print(a)
    i = torch.LongTensor([4,3,2])
    t = torch.zeros(10).long()
    t.scatter_(0, i, a)
    print(t)
    

    输出为:

    tensor([1, 2, 3, 4, 5])
    tensor([0, 0, 3, 2, 1, 0, 0, 0, 0, 0])
    

    可以看到这里,源tensor(a)是一个一维,包含5个元素。目标tensor(t),在这里是一个10个元素的tensor,为了大家看得方便,我们先把所有元素设置为0,然后再把源tensor里面的元素搬过来放进目标tensor里面的时候,就很容易看到,被index tensor里面的信息所影响到的元素是非0的,如果没受到影响的是0。

    这里源tensor只有5个元素,那么都搬过来,目标tensor(t)里面的元素也还是有10-5=5个元素是不会受到影响的,即为0。

    那么为什么上面看到目标tensor里面的非0元素的个数只有3个,而不是5个(等于源tensor的个数)? 回顾一下对于3-D的tensor,当dim=0的时候,元素设置的公式:

    self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
    

    显然,对于1-D的tensor,上面的公式简化为:

    self[index[i]] = src[i]  # if dim == 0	
    

    因为这里index只有三个元素[4,3,2],那么意味着,再把源tensor(a)里面的5个元素放置到目标tensor(t)的过程中,只有i取值为0,1,2的,才能使用index里面的值,其余2个(在a里面的位置分别为4,5),就不再般src里面的元素了。我们来逐个元素说明一下:

    • 当 i == 0时,self[index[0]] = src[0],即self[4] = src[0],也就是把src里面的第1个元素设置到self的第4个元素,这里src[0] 即是 a[0],是1,而self[4],即t[4]被设置为了1.
    • 当 i == 1时,self[index[1]] = src[1],即self[3] = src[1],也就是把src里面的第2个元素设置到self的第3个元素,这里src[1] 即是 a[1],是2,而self[3],即t[3]被设置为了2.
    • 当 i == 2时,self[index[2]] = src[2],即self[2] = src[2],也就是把src里面的第3个元素设置到self的第2个元素,这里src[2] 即是 a[2],是3,而self[2],即t[2]被设置为了3.
    • 当 i == 3 和4的是,index里面已经没有对应的数值了,这些元素就不处理了。

    2. 当dim为1的时候

    说明,src,目标tensor和index都至少是2-D的,如果设置dim = 1,将会导致PyTorch报错。错误信息如下(对于1-D的index):

    IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

    对于一个3-D的tensor,我们把src里面的元素放置到self里面的时候,假设是放置src的第[i][j][k]个元素,那么放置到self里面的位置(三个纬度的值)分别如下:

    • i
    • index[i][j][k]
    • k

    对于第一个纬度的位置,就是i,元素在src里面的位置是什么,在self里面也是相同的。

    对于第二个纬度的位置,就是j,元素在self里面的位置变成了index[i][j][k]。

    那么同样地,如果index的size比src的size要小的话,怎么办?那就是对于在index里面,找不到的值,就不在处理,即self里面是什么值还是什么值,不会变化。

    对于一个2-D的tensor,我们把src里面的元素放置到self里面的时候,假设是放置src的第[i][j]个元素,那么放置到self里面的位置(个纬度的值)分别如下:

    • i
    • index[i][j]

    为了更加方便理解,这里假设src是2-D的,即一个2维数组,且dim==1的情况下:

    a = torch.arange(1, 11).long().reshape(2,5)
    print(a)
    i = torch.LongTensor([[4], [3]])
    t = torch.zeros(10).long().reshape(2, 5)
    t.scatter_(1, i, a)
    print(t)
    

    输出如下:

    tensor([[ 1,  2,  3,  4,  5],
            [ 6,  7,  8,  9, 10]])
    tensor([[0, 0, 0, 0, 1],
            [0, 0, 0, 6, 0]])
    

    在这里,index里面只有两个元素,那么也就是最终会有两个元素的值从src里面取出,设置到a里面去。在index里面仅有的两个元素是index[0][0]和index[1][0],这两个对应的src的元素是a[0][0]和a[1][0],对应的目的tensor(t)里面的t[0][index[0][0]]和t[1][index[1][0]]元素,即t[0][4]将会被设置为a[0][0],t[1][3]将会被设置为a[1][0],即:

    • t[0][4] = 1
    • t[1][3] = 6

    其他目的tensor(t)里面的值都不会变。

    3. 当dim为2的时候

    大家可以按照上面说明的规则,自己进行推导,就不在这里赘述了。

    总结:

    scatter或者scatter_函数的作用就是把src里面的元素按照index和dim参数给出的条件,放置到目标tensor里面去。index有几个元素,就会有几个元素被从src里面放到目标tensor里面,其余目标tensor里面的元素不受影响。

  • 相关阅读:
    Web API 强势入门指南
    毫秒必争,前端网页性能最佳实践
    Windbg Extension NetExt 使用指南 【3】 ---- 挖掘你想要的数据 Managed Heap
    Windbg Extension NetExt 使用指南 【2】 ---- NetExt 的基本命令介绍
    Windbg Extension NetExt 使用指南 【1】 ---- NetExt 介绍
    WCF : 修复 Security settings for this service require Windows Authentication but it is not enabled for the IIS application that hosts this service 问题
    透过WinDBG的视角看String
    Microsoft Azure Web Sites应用与实践【4】—— Microsoft Azure网站的“后门”
    企业IT管理员IE11升级指南【17】—— F12 开发者工具
    WCF : 如何将NetTcpBinding寄宿在IIS7上
  • 原文地址:https://www.cnblogs.com/jizhao/p/15515413.html
Copyright © 2011-2022 走看看