zoukankan      html  css  js  c++  java
  • torch 数组按顺序进行编号,用在hard negative mining

    在hard negative mining中,要负样本的得分进行排序,选取得分最高的参与训练。这就涉及到选取top-k个负样本或掩膜,因此要对数组按顺序进行编号。
    在pytorch中,torch.sort()函数可以返回排序好的数组和元素的索引,用两次torch.sort()函数即可以得到数组元素的编号。

    b = torch.rand(5)
    print(b)
    
    # 从小到大编号
    _, index = b.sort()
    _, order = index.sort()
    print(order)
    
    # 从大到小进编号
    _, index = b.sort(descending=True)
    _, order = index.sort()
    print(order)
    

    附上一段hard example mining

    def hard_negative_mining(loss, labels, neg_pos_ratio):
        """
        It used to suppress the presence of a large number of negative prediction.
        It works on image level not batch level.
        For any example/image, it keeps all the positive predictions and
         cut the number of negative predictions to make sure the ratio
         between the negative examples and positive examples is no more
         the given ratio for an image.
    
        Args:
            loss (N, num_priors): the loss for each example.
            labels (N, num_priors): the labels.
            neg_pos_ratio:  the ratio between the negative examples and positive examples.
        """
        pos_mask = labels > 0
        num_pos = pos_mask.long().sum(dim=1, keepdim=True)
        num_neg = num_pos * neg_pos_ratio
    
        loss[pos_mask] = -math.inf
        _, indexes = loss.sort(dim=1, descending=True)
        _, orders = indexes.sort(dim=1)
        neg_mask = orders < num_neg
        return pos_mask | neg_mask
    
  • 相关阅读:
    (打包报错)AS打包报错:Generate Signed APK: Errors while building APK. You can find the errors in the 'Messages' view.
    NABCD
    石家庄地铁站项目最终总结报告
    团队冲刺2.7
    团队冲刺2.6
    团队冲刺2.5
    团队冲刺2.4
    团队冲刺2.3
    用户体验评价——win10自带微软拼音输入法
    团队冲刺2.2
  • 原文地址:https://www.cnblogs.com/zi-wang/p/12292585.html
Copyright © 2011-2022 走看看