zoukankan      html  css  js  c++  java
  • pytorch :使用两次sort函数(排序)找出矩阵每个元素在升序或降序排列中的位置

    在SSD的代码中经常有见到如下的操作:

            _, idx = flt[:, :, 0].sort(1, descending=True)#大小为[batch size, num_classes*top_k]
            _, rank = idx.sort(1)#再对索引升序排列,得到其索引作为排名rank

    其作用是什么呢?举个例子:

    import torch
    a = torch.randn(3,4)
    print(a)
    print()
    
    i, idx = a.sort(dim=1, descending=True)
    print(i)
    print(idx)
    print()
    
    j, rank = idx.sort(dim=1)
    print(rank)

    返回:

    tensor([[ 2.3326,  0.0275, -0.0799,  0.4156],
            [-2.2066,  1.7997, -2.2767,  0.4704],
            [-0.6980,  0.2285,  1.0018, -0.8874]])
    
    tensor([[ 2.3326,  0.4156,  0.0275, -0.0799],
            [ 1.7997,  0.4704, -2.2066, -2.2767],
            [ 1.0018,  0.2285, -0.6980, -0.8874]])
    tensor([[0, 3, 1, 2],
            [1, 3, 0, 2],
            [2, 1, 0, 3]])
    
    tensor([[0, 2, 3, 1],
            [2, 0, 3, 1],
            [2, 1, 0, 3]])

    其实就是可以通过最后的rank看出对应位置的值的排序位置,这里是降序,所以索引0表示最高

    以rank的第一行第一列的值0为例,其表示对应的a中第一行第一列的值2.3326是第一行中最大的,因为设置了dim=1;第一行第三列的值3为例,其表示对应的a中第一行第三列的值-0.0799是第一行中最小的

    起到了对应位置排名的作用

  • 相关阅读:
    视觉里程计VO-直接法
    Linux安装libcholmod-dev找不到的解决方法
    Levenberg-Marquadt Method
    Gauss-Newton Method
    CMake
    方差 标准差 协方差
    SFM
    矩阵分解
    kvm学习笔记
    python学习笔记
  • 原文地址:https://www.cnblogs.com/wanghui-garcia/p/12982732.html
Copyright © 2011-2022 走看看