zoukankan      html  css  js  c++  java
  • Pytorch取最小或最大的张量索引

    Pytorch中根据索引取张量有很多方法,比如index_select和masked_select,和gt,ge等配合食用,但如果需要取出最小几个或最大几个张量的索引,则需要动手写一下

    a = torch.tensor([2,3,1,5])
    y,_ = torch.sort(a)
    mask = a.gt(y[0])
    index = []
    mask_list = (mask== False).nonzero()
    index = [int(i) for i in mask_list]
    index
    
    >>>[2]

    先对张量做排序,然后求出原张量大于和小于等 最小值y[0]的掩码,大于为True,小于等于为False,然后用nonezero()方法就可以求出掩码中False的索引,done

  • 相关阅读:
    java 学习帮助
    权限
    ftp mybatis
    注解
    hadoop english
    userDao
    发布订阅模式 和委托
    webservice
    rabbitMq视频教程
    blog url.txt
  • 原文地址:https://www.cnblogs.com/yqpy/p/12561779.html
Copyright © 2011-2022 走看看