zoukankan      html  css  js  c++  java
  • Pytorch的gather用法理解

    先放一张表,可以看成是二维数组

    行(列)索引 索引0 索引1 索引2 索引3
    索引0 0 1 2 3
    索引1 4 5 6 7
    索引2 8 9 10 11
    索引3 12 13 14 15

    看一下下面例子代码:

    针对0维(输出为行形式)

    >>> import torch as t
    >>> a = t.arange(0,16).view(4,4)
    >>> a
    tensor([[ 0,  1,  2,  3],
            [ 4,  5,  6,  7],
            [ 8,  9, 10, 11],
            [12, 13, 14, 15]])
     
    #选取对角线的元素
    >>> index = t.LongTensor([[0,1,2,3]])
    >>> a.gather(0,index)
    tensor([[ 0,  5, 10, 15]])
    

    如何理解结果呢?其实很简单,就是a.gather(0,index)中第一个0已经表明输出结果是行形式(0维),如果第一个是1说明输出结果是列形式(1维),然后按照index = tensor([[0, 1, 2, 3]])顺序作用在行上索引依次为0,1,2,3

    • a[0][0] = 0
    • a[1][1] = 5
    • a[2][2] = 10
    • a[3][3] = 15

    针对0维

    # 选取反对角线上的元素,注意与上面的不同
    >>> index2 = t.LongTensor([[3,2,1,0]])
    >>> a.gather(0,index2)
    tensor([[12,  9,  6,  3]])
    

    如何理解结果呢?同理,按照index = tensor([[3, 2, 1, 0]])顺序作用在行上索引依次为3,2,1,0:

    • a[3][0] = 12
    • a[2][1] = 9
    • a[1][2] = 6
    • a[0][3] = 3

    针对1维(输出为列形式)

    选取对角线的元素

    >>> index3 = t.LongTensor([[0,1,2,3]]).t()
    >>> a.gather(1,index3)
    tensor([[ 0],
            [ 5],
            [10],
            [15]])
    

    如何理解结果呢?同理,按照index = tensor([[0, 1, 2, 3]])顺序作用在列上索引依次为0,1,2,3:

    • a[0][0] = 0
    • a[1][1] = 5
    • a[2][2] = 10
    • a[3][3] = 15

    针对1维

    选取反对角线上的元素

    >>> index4 = t.LongTensor([[3,2,1,0]]).t()
    >>> a.gather(1,index4)
    tensor([[ 3],
            [ 6],
            [ 9],
            [12]])
    

    如何理解结果呢?同理,按照index = tensor([[3, 2, 1, 0]])顺序作用在列上索引依次为3,2,1,0:

    • a[0][3] = 3
    • a[1][2] = 6
    • a[2][1] = 9
    • a[3][0] = 12
  • 相关阅读:
    05 drf源码剖析之认证
    04 drf源码剖析之版本
    03 drf源码剖析之视图
    02 drf源码剖析之快速了解drf
    OA之为用户设置角色和为用户设置权限
    EasyUI之datagrid的使用
    C#之反射
    EasyUI之Layout布局和Tabs页签的使用
    Memcached的使用
    Log4Net的简单使用
  • 原文地址:https://www.cnblogs.com/HongjianChen/p/9451526.html
Copyright © 2011-2022 走看看