zoukankan      html  css  js  c++  java
  • Iou及NMS实现

    Iou

    Jaccard系数(Jaccard index)可以衡量两个集合的相似度。给定集合\(\mathcal{A}\)\(\mathcal{B}\),它们的Jaccard系数即二者交集大小除以二者并集大小:

    \[J(\mathcal{A},\mathcal{B}) = \frac{\left|\mathcal{A} \cap \mathcal{B}\right|}{\left| \mathcal{A} \cup \mathcal{B}\right|}. \]

    通常将Jaccard系数称为交并比(Intersection over Union,IoU),即两个边界框相交面积与相并面积之比,如图所示。交并比的取值范围在0和1之间:0表示两个边界框无重合像素,1表示两个边界框相等。

    点击展开:Iou实现
    # 参考https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Object-Detection/blob/master/utils.py#L340
    def find_intersection(set_1, set_2):
        """
        Find the intersection of every box combination between two sets of boxes that are in boundary coordinates.
        :param set_1: set 1, a tensor of dimensions (n1, 4)
        :param set_2: set 2, a tensor of dimensions (n2, 4)
        :return: intersection of each of the boxes in set 1 with respect to each of the boxes in set 2, a tensor of dimensions (n1, n2)
        """
    
        # PyTorch auto-broadcasts singleton dimensions
        lower_bounds = torch.max(set_1[:, :2].unsqueeze(1), set_2[:, :2].unsqueeze(0))  # (n1, n2, 2)
        upper_bounds = torch.min(set_1[:, 2:].unsqueeze(1), set_2[:, 2:].unsqueeze(0))  # (n1, n2, 2)
        intersection_dims = torch.clamp(upper_bounds - lower_bounds, min=0)  # (n1, n2, 2)
        return intersection_dims[:, :, 0] * intersection_dims[:, :, 1]  # (n1, n2)
    
    
    def find_jaccard_overlap(set_1, set_2):
        """
        Find the Jaccard Overlap (IoU) of every box combination between two sets of boxes that are in boundary coordinates.
        :param set_1: set 1, a tensor of dimensions (n1, 4)
        :param set_2: set 2, a tensor of dimensions (n2, 4)
        :return: Jaccard Overlap of each of the boxes in set 1 with respect to each of the boxes in set 2, a tensor of dimensions (n1, n2)
        """
    
        # Find intersections
        intersection = find_intersection(set_1, set_2)  # (n1, n2)
    
        # Find areas of each box in both sets
        areas_set_1 = (set_1[:, 2] - set_1[:, 0]) * (set_1[:, 3] - set_1[:, 1])  # (n1)
        areas_set_2 = (set_2[:, 2] - set_2[:, 0]) * (set_2[:, 3] - set_2[:, 1])  # (n2)
    
        # Find the union
        # PyTorch auto-broadcasts singleton dimensions
        union = areas_set_1.unsqueeze(1) + areas_set_2.unsqueeze(0) - intersection  # (n1, n2)
    
        return intersection / union  # (n1, n2)
    

    NMS

    在模型预测阶段,我们先为图像生成多个锚框,并为这些锚框一一预测类别和偏移量。随后,我们根据锚框及其预测偏移量得到预测边界框。当锚框数量较多时,同一个目标上可能会输出较多相似的预测边界框。为了使结果更加简洁,我们可以移除相似的预测边界框。常用的方法叫作非极大值抑制(non-maximum suppression,NMS)。

    NMS工作原理:
    对于一个预测边界框\(B\),模型会计算各个类别的预测概率。设其中最大的预测概率为\(p\),该概率所对应的类别即\(B\)的预测类别。我们也将\(p\)称为预测边界框\(B\)的置信度。在同一图像上,我们将预测类别非背景的预测边界框按置信度从高到低排序,得到列表\(L\)。从\(L\)中选取置信度最高的预测边界框\(B_1\)作为基准,将所有与\(B_1\)的交并比大于某阈值的非基准预测边界框从\(L\)中移除。这里的阈值是预先设定的超参数。此时,\(L\)保留了置信度最高的预测边界框并移除了与其相似的其他预测边界框。接下来,从\(L\)中选取置信度第二高的预测边界框\(B_2\)作为基准,将所有与\(B_2\)的交并比大于某阈值的非基准预测边界框从\(L\)中移除。重复这一过程,直到\(L\)中所有的预测边界框都曾作为基准。此时\(L\)中任意一对预测边界框的交并比都小于阈值。最终,输出列表\(L\)中的所有预测边界框。

    • step1 预测框按照置信度排序,挑选置信度最高的边界框\(A\)
    • step2 计算其他预测框和\(A\)的Iou值
    • step3 根据Iou值过滤超过阈值(超参数)的预测框
    • step4 找剩下中置信度最高的边界框,然后重复123步,直到所有Iou值都满足要求
    点击展开:NMS实现
    from collections import namedtuple
    Pred_BB_Info = namedtuple("Pred_BB_Info", ["index", "class_id", "confidence", "xyxy"])
    
    def non_max_suppression(bb_info_list, nms_threshold = 0.5):
        """
        非极大抑制处理预测的边界框
        Args:
            bb_info_list: Pred_BB_Info的列表, 包含预测类别、置信度等信息
            nms_threshold: 阈值
        Returns:
            output: Pred_BB_Info的列表, 只保留过滤后的边界框信息
        """
        output = []
        # 先根据置信度从高到低排序
        sorted_bb_info_list = sorted(bb_info_list, key = lambda x: x.confidence, reverse=True)
    
        while len(sorted_bb_info_list) != 0:
            best = sorted_bb_info_list.pop(0)
            output.append(best)
            
            if len(sorted_bb_info_list) == 0:
                break
    
            bb_xyxy = []
            for bb in sorted_bb_info_list:
                bb_xyxy.append(bb.xyxy)
            
            iou = compute_jaccard(torch.tensor([best.xyxy]), 
                                  torch.tensor(bb_xyxy))[0] # shape: (len(sorted_bb_info_list), )
            
            n = len(sorted_bb_info_list)
            sorted_bb_info_list = [sorted_bb_info_list[i] for i in range(n) if iou[i] <= nms_threshold]
        return output
    
    def MultiBoxDetection(cls_prob, loc_pred, anchor, nms_threshold = 0.5):
        """
        # 按照「9.4.1. 生成多个锚框」所讲的实现, anchor表示成归一化(xmin, ymin, xmax, ymax).
        https://zh.d2l.ai/chapter_computer-vision/anchor.html
        Args:
            cls_prob: 经过softmax后得到的各个锚框的预测概率, shape:(bn, 预测总类别数+1, 锚框个数)
            loc_pred: 预测的各个锚框的偏移量, shape:(bn, 锚框个数*4)
            anchor: MultiBoxPrior输出的默认锚框, shape: (1, 锚框个数, 4)
            nms_threshold: 非极大抑制中的阈值
        Returns:
            所有锚框的信息, shape: (bn, 锚框个数, 6)
            每个锚框信息由[class_id, confidence, xmin, ymin, xmax, ymax]表示
            class_id=-1 表示背景或在非极大值抑制中被移除了
        """
        assert len(cls_prob.shape) == 3 and len(loc_pred.shape) == 2 and len(anchor.shape) == 3
        bn = cls_prob.shape[0]
        
        def MultiBoxDetection_one(c_p, l_p, anc, nms_threshold = 0.5):
            """
            MultiBoxDetection的辅助函数, 处理batch中的一个
            Args:
                c_p: (预测总类别数+1, 锚框个数)
                l_p: (锚框个数*4, )
                anc: (锚框个数, 4)
                nms_threshold: 非极大抑制中的阈值
            Return:
                output: (锚框个数, 6)
            """
            pred_bb_num = c_p.shape[1]
            anc = (anc + l_p.view(pred_bb_num, 4)).detach().cpu().numpy() # 加上偏移量
            
            confidence, class_id = torch.max(c_p, 0)
            confidence = confidence.detach().cpu().numpy()
            class_id = class_id.detach().cpu().numpy()
            
            pred_bb_info = [Pred_BB_Info(
                                index = i,
                                class_id = class_id[i] - 1, # 正类label从0开始
                                confidence = confidence[i],
                                xyxy=[*anc[i]]) # xyxy是个列表
                            for i in range(pred_bb_num)]
            
            # 正类的index
            obj_bb_idx = [bb.index for bb in non_max_suppression(pred_bb_info, nms_threshold)]
            
            output = []
            for bb in pred_bb_info:
                output.append([
                    (bb.class_id if bb.index in obj_bb_idx else -1.0),
                    bb.confidence,
                    *bb.xyxy
                ])
                
            return torch.tensor(output) # shape: (锚框个数, 6)
        
        batch_output = []
        for b in range(bn):
            batch_output.append(MultiBoxDetection_one(cls_prob[b], loc_pred[b], anchor[0], nms_threshold))
        
        return torch.stack(batch_output)
    

    参考链接

    a-PyTorch-Tutorial-to-Object-Detection
    Dive-into-DL-PyTorch

  • 相关阅读:
    用MySQL的注意事项
    在win下mysql备份恢复命令概述
    SQL查询结果集对注入的影响与利用
    DIV CSS完美兼容IE6/IE7/FF的通用方法
    使用css实现透视的效果
    ASP.NET几个性能优化的方法
    ASP.NET实现页面传值的几种方法
    ASP.NET配置文件Web.config 详细解释
    黑客域名劫持攻击详细步骤
    FCKeditor的几点修改小结
  • 原文地址:https://www.cnblogs.com/xiaxuexiaoab/p/15679133.html
Copyright © 2011-2022 走看看