zoukankan      html  css  js  c++  java
  • 关于Pytorch的二维tensor的gather和scatter_操作用法分析

    看得不明不白(我在下一篇中写了如何理解gather的用法)

    gather是一个比较复杂的操作,对一个2维tensor,输出的每个元素如下:

    out[i][j] = input[index[i][j]][j]  # dim=0
    out[i][j] = input[i][index[i][j]]  # dim=1
    

    二维tensor的gather操作

    针对0轴

    注意index此时的值

    输入

    index = t.LongTensor([[0,1,2,3]])
    print("index = 
    ", index)      #index是2维
    print("index的形状: ",index.shape)  #index形状是(1,4)  
    

    输出

    index = 
     tensor([[0, 1, 2, 3]])
    index的形状:  torch.Size([1, 4])
    

    分割线============

    针对1轴

    注意index此时的值

    输入

    index = t.LongTensor([[0,1,2,3]]).t()  #index是2维
    print("index = 
    ", index)    #index形状是(4,1)
    print("index的形状: ",index.shape)
    

    输出

    index = 
     tensor([[0],
            [1],
            [2],
            [3]])
    index的形状:  torch.Size([4, 1])
    

    分割线===========

    再来看看几个例子

    注意index在以0轴和1轴为标准时的表达式是不一样的。
    b.gather()中取0维时,输出的结果是行形式,取1维时,输出的结果是列形式。

    • b是一个 $ 3 imes4 $ 型的
    >>> import torch as t
    >>> b = t.arange(0,12).view(3,4)
    >>> b
    tensor([[ 0,  1,  2,  3],
            [ 4,  5,  6,  7],
            [ 8,  9, 10, 11]])
    >>> index = t.LongTensor([[0,1,2]])
    
    >>> index
    tensor([[0, 1, 2]])
    
    >>> b.gather(0,index)     #运行失败了
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    RuntimeError: Expected tensor [1 x 3], src [3 x 4] and index [1 x 3] to have the same size apart from dimension 0 at c:
    ew-builder_3win-wheelpytorchatensrc	hgeneric/THTensorMath.cpp:620    
    
    >>> index2 = t.LongTensor([[0,1,2]]).t()
    
    >>> b.gather(1,index2)  #运行成功了
    tensor([[ 0],
            [ 5],
            [10]])
    
    >>> index3 = t.LongTensor([[0,1,2,3]]).t()
    
    >>> b.gather(1,index3)  #运行失败了
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    RuntimeError: Expected tensor [4 x 1], src [3 x 4] and index [4 x 1] to have the same size apart from dimension 1 at c:
    ew-builder_3win-wheelpytorchatensrc	hgeneric/THTensorMath.cpp:620
    
    • b是一个 $ 6 imes6 $ 型的
    >>> import torch as t
    >>> b = t.arange(0,36).view(6,6)
    >>> b
    tensor([[ 0,  1,  2,  3,  4,  5],
            [ 6,  7,  8,  9, 10, 11],
            [12, 13, 14, 15, 16, 17],
            [18, 19, 20, 21, 22, 23],
            [24, 25, 26, 27, 28, 29],
            [30, 31, 32, 33, 34, 35]])
    
    >>> index = t.LongTensor([[0,1,2,3,4,5,6]])
    >>> b.gather(0,index)     #运行失败了
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    RuntimeError: Expected tensor [1 x 7], src [6 x 6] and index [1 x 7] to have the same size apart from dimension 0 at c:
    ew-builder_3win-wheelpytorchatensrc	hgeneric/THTensorMath.cpp:620  
    
    >>> index = t.LongTensor([[0,1,2,3,4,5]])
    >>> b.gather(0,index)    #运行成功了
    tensor([[ 0,  7, 14, 21, 28, 35]])
    >>> b.gather(1,index)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    RuntimeError: Expected tensor [1 x 6], src [6 x 6] and index [1 x 6] to have the same size apart from dimension 1 at c:
    ew-builder_3win-wheelpytorchatensrc	hgeneric/THTensorMath.cpp:620
    
    >>> index2 = t.LongTensor([[0,1,2,3,4,5]]).t()  
    >>> b.gather(1,index2)     #运行成功了
    tensor([[ 0],     
            [ 7],
            [14],
            [21],
            [28],
            [35]])
    
    >>> index3 = t.LongTensor([[0,1,2,3,4]]).t()
    >>> b.gather(1,index3)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    RuntimeError: Expected tensor [5 x 1], src [6 x 6] and index [5 x 1] to have the same size apart from dimension 1 at c:
    ew-builder_3win-wheelpytorchatensrc	hgeneric/THTensorMath.cpp:620  
    
    >>> index4 = t.LongTensor([[0,1,2,3,4]])
    >>> b.gather(0,index4)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    RuntimeError: Expected tensor [1 x 5], src [6 x 6] and index [1 x 5] to have the same size apart from dimension 0 at c:
    ew-builder_3win-wheelpytorchatensrc	hgeneric/THTensorMath.cpp:620
    
    

    与gather相对应的逆操作是scatter_,gather把数据从input中按index取出,而scatter_是把取出的数据再放回去。注意scatter_函数是inplace操作。



    与gather相对应的逆操作是scatter_,gather把数据从input中按index取出,而scatter_是把取出的数据再放回去。注意scatter_函数是inplace操作。

    out = input.gather(dim, index)
    -->近似逆操作
    out = Tensor()
    out.scatter_(dim, index)
    

    根据StackOverflow上的问题修改代码如下:
    输入

    # 把两个对角线元素放回去到指定位置
    c = t.zeros(4,4)
    c.scatter_(1, index, b.float())
    

    输出

    tensor([[ 0.,  0.,  0.,  3.],
            [ 0.,  5.,  6.,  0.],
            [ 0.,  9., 10.,  0.],
            [12.,  0.,  0., 15.]])
    
  • 相关阅读:
    求解整数集合的交集(腾讯笔试)
    关于屏幕适配之比例布局
    (转)注册JNI函数的两种方式
    正则表达式记录
    当年一个简单可用的多线程断点续传类
    最近用到的几个工具方法
    Android中包含List成员变量的Parcel以及Parcel嵌套写法示例
    java实现计算MD5
    一个用于去除状态栏和虚拟导航栏的BaseActivity
    MVP的模板
  • 原文地址:https://www.cnblogs.com/HongjianChen/p/9450987.html
Copyright © 2011-2022 走看看