zoukankan      html  css  js  c++  java
  • Pytorch中的torch.gather函数

    gather函数的的官方文档:

    
    torch.gather(input, dim, index, out=None) → Tensor
    
        Gathers values along an axis specified by dim.
    
        For a 3-D tensor the output is specified by:
    
        out[i][j][k] = input[index[i][j][k]][j][k]  # dim=0
        out[i][j][k] = input[i][index[i][j][k]][k]  # dim=1
        out[i][j][k] = input[i][j][index[i][j][k]]  # dim=2
    
        Parameters: 
    
            input (Tensor) – The source tensor
            dim (int) – The axis along which to index
            index (LongTensor) – The indices of elements to gather
            out (Tensor, optional) – Destination tensor
    
        Example:
    
        >>> t = torch.Tensor([[1,2],[3,4]])
        >>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
         1  1
         4  3
        [torch.FloatTensor of size 2x2]
    

    例子:

    a=t.arange(0,16).view(4,4)
    print(a)
    
    index_1=t.LongTensor([[3,2,1,0]])
    b=a.gather(0,index_1)
    print(b)
    
    index_2=t.LongTensor([[0,1,2,3]]).t()#tensor转置操作:(a)T=a.t()
    c=a.gather(1,index_2)
    print(c)
    
    

    输出如下:

    tensor([[ 0,  1,  2,  3],
            [ 4,  5,  6,  7],
            [ 8,  9, 10, 11],
            [12, 13, 14, 15]])
            
    tensor([[12,  9,  6,  3]])
    
    tensor([[ 0],
            [ 5],
            [10],
            [15]])
    
    

    在上面的例子中,a是一个4×4矩阵:
    1)当维度dim=0,索引index_1为[3,2,1,0]时,此时可将a看成1×4的矩阵,通过index_1对a每列进行行索引:第一列第四行元素为12,第二列第三行元素为9,第三列第二行元素为6,第四列第一行元素为3,即b=[12,9,6,3];
    2)当维度dim=1,索引index_2为[0,1,2,3]T时,此时可将a看成4×1的矩阵,通过index_1对a每行进行列索引:第一行第一列元素为0,第二行第二列元素为5,第三行第三列元素为10,第四行第四列元素为15,即c=[0,5,10,15]T。

    例子二:

    y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
    y = torch.LongTensor([0, 2])
    y_hat.gather(1, y.view(-1, 1))
    

    输出:

    tensor([[0.1000],
     [0.5000]])
    

    总结:
    gather函数在提取数据时主要靠dimindex这两个参数:
    dim=1时将input看为n×1阶矩阵,index看为k×1阶矩阵,取index每元素对input中每进行索引(如:index某行为[1,3,0],对应的input行元素为[9,8,7,6],提取后的结果为[8,6,9]);
    同理,dim=0时将input看为1×n阶矩阵,index看为1×k阶矩阵,取index每元素对input中每进行索引。gather函数提取后的矩阵阶数和对应的index阶数相同。
    参考:https://blog.csdn.net/weixin_44318872/article/details/103183763?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-1.add_param_isCf&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-1.add_param_isCf

  • 相关阅读:
    好的博客
    left join 后边的on条件 小记
    ElasticSearch构建订单服务的博客
    nacos mysql8.0修改
    maven配置
    idea常用配置
    http状态码
    Web Application:Exploded和Web Application:Archive
    将一个简单远程调用的方式例子改为异步调用 -- 2
    将一个简单远程调用的方式例子改为异步调用
  • 原文地址:https://www.cnblogs.com/Jason66661010/p/13585434.html
Copyright © 2011-2022 走看看