zoukankan      html  css  js  c++  java
  • MMDetection源码解析:Faster RCNN(8)--BBoxHead类

    BBoxHead类继承自nn.Module类,定义在mmdet/models/roi_heads/bbox_heads/bbox_head.py中,其作用是根据RoI Pooling得到的特征输出分类得分和边界框回归值.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from mmcv.runner import auto_fp16, force_fp32
    from torch.nn.modules.utils import _pair
    
    from mmdet.core import build_bbox_coder, multi_apply, multiclass_nms
    from mmdet.models.builder import HEADS, build_loss
    from mmdet.models.losses import accuracy
    
    
    @HEADS.register_module()
    class BBoxHead(nn.Module):
        """Simplest RoI head, with only two fc layers for classification and
        regression respectively."""
    
        def __init__(self,
                     with_avg_pool=False,
                     with_cls=True,
                     with_reg=True,
                     roi_feat_size=7,
                     in_channels=256,
                     num_classes=80,
                     bbox_coder=dict(
                         type='DeltaXYWHBBoxCoder',
                         target_means=[0., 0., 0., 0.],
                         target_stds=[0.1, 0.1, 0.2, 0.2]),
                     reg_class_agnostic=False,
                     reg_decoded_bbox=False,
                     loss_cls=dict(
                         type='CrossEntropyLoss',
                         use_sigmoid=False,
                         loss_weight=1.0),
                     loss_bbox=dict(
                         type='SmoothL1Loss', beta=1.0, loss_weight=1.0)):
            """Create the box coder, the two losses and the output fc layers.

            Args:
                with_avg_pool (bool): If True, an ``nn.AvgPool2d`` with kernel
                    ``roi_feat_size`` is created and applied in ``forward``
                    before the fc layers; otherwise the fc layers take the
                    flattened ``in_channels * roi_feat_area`` features.
                with_cls (bool): Create ``fc_cls`` with ``num_classes + 1``
                    outputs (the extra one is the background class).
                with_reg (bool): Create ``fc_reg`` with 4 outputs when
                    ``reg_class_agnostic`` is True, else ``4 * num_classes``.
                roi_feat_size (int | tuple): Spatial size of the RoI feature,
                    normalised to a pair with ``_pair``.
                in_channels (int): Channel count of the RoI feature.
                num_classes (int): Number of foreground classes.
                bbox_coder (dict): Config passed to ``build_bbox_coder``.
                reg_class_agnostic (bool): Share one box regressor across
                    all classes.
                reg_decoded_bbox (bool): Stored on the instance; read by
                    ``_get_target_single`` and ``loss`` to decide whether
                    boxes are handled in decoded form.
                loss_cls (dict): Config passed to ``build_loss``.
                loss_bbox (dict): Config passed to ``build_loss``.
            """
            super().__init__()
            # At least one of the two output branches has to exist.
            assert with_cls or with_reg

            # Plain configuration flags.
            self.with_avg_pool = with_avg_pool
            self.with_cls = with_cls
            self.with_reg = with_reg
            self.roi_feat_size = _pair(roi_feat_size)
            self.roi_feat_area = self.roi_feat_size[0] * self.roi_feat_size[1]
            self.in_channels = in_channels
            self.num_classes = num_classes
            self.reg_class_agnostic = reg_class_agnostic
            self.reg_decoded_bbox = reg_decoded_bbox
            self.fp16_enabled = False

            # Components built from their configs.
            self.bbox_coder = build_bbox_coder(bbox_coder)
            self.loss_cls = build_loss(loss_cls)
            self.loss_bbox = build_loss(loss_bbox)

            # Input width of the fc layers: pooled features keep the channel
            # count, un-pooled features are flattened over the RoI area.
            if self.with_avg_pool:
                self.avg_pool = nn.AvgPool2d(self.roi_feat_size)
                fc_in_features = self.in_channels
            else:
                fc_in_features = self.in_channels * self.roi_feat_area

            if self.with_cls:
                # need to add background class
                self.fc_cls = nn.Linear(fc_in_features, num_classes + 1)
            if self.with_reg:
                if reg_class_agnostic:
                    reg_out_features = 4
                else:
                    reg_out_features = 4 * num_classes
                self.fc_reg = nn.Linear(fc_in_features, reg_out_features)
            self.debug_imgs = None
    
        def init_weights(self):
            """Initialise the existing fc layers.

            Weights are drawn from a zero-mean normal distribution (std 0.01
            for ``fc_cls``, std 0.001 for ``fc_reg``) and biases are set to 0.
            """
            # conv layers are already initialized by ConvModule
            fc_specs = (
                (self.with_cls, 'fc_cls', 0.01),
                (self.with_reg, 'fc_reg', 0.001),
            )
            for enabled, layer_name, std in fc_specs:
                if not enabled:
                    # The layer was never created in __init__.
                    continue
                layer = getattr(self, layer_name)
                nn.init.normal_(layer.weight, 0, std)
                nn.init.constant_(layer.bias, 0)
    
        @auto_fp16()
        def forward(self, x):
            """Compute classification scores and box predictions.

            Args:
                x (Tensor): RoI features; the first dimension is kept and all
                    remaining dimensions are flattened before the fc layers.

            Returns:
                tuple: ``(cls_score, bbox_pred)``. An entry is ``None`` when
                the corresponding branch (``with_cls`` / ``with_reg``) is
                disabled.
            """
            if self.with_avg_pool:
                x = self.avg_pool(x)
            flat_feats = x.view(x.size(0), -1)

            cls_score = None
            if self.with_cls:
                cls_score = self.fc_cls(flat_feats)

            bbox_pred = None
            if self.with_reg:
                bbox_pred = self.fc_reg(flat_feats)

            return cls_score, bbox_pred
    
        def _get_target_single(self, pos_bboxes, neg_bboxes, pos_gt_bboxes,
                               pos_gt_labels, cfg):
            """Build training targets for the sampled boxes of one image.

            Positive samples occupy the first ``num_pos`` rows of every
            returned tensor, negative samples the rows after them.

            Args:
                pos_bboxes (Tensor): Positive sampled boxes; its first
                    dimension gives the number of positives.
                neg_bboxes (Tensor): Negative sampled boxes; only its first
                    dimension is used.
                pos_gt_bboxes (Tensor): Ground-truth boxes matched to
                    ``pos_bboxes``.
                pos_gt_labels (Tensor): Labels matched to ``pos_bboxes``.
                cfg: Training config; only ``cfg.pos_weight`` is read.

            Returns:
                tuple: ``(labels, label_weights, bbox_targets, bbox_weights)``
                with ``num_pos + num_neg`` rows each; the two bbox tensors
                have 4 columns.
            """
            num_pos = pos_bboxes.size(0)
            num_neg = neg_bboxes.size(0)
            total = num_pos + num_neg

            # original implementation uses new_zeros since BG are set to be 0
            # now use empty & fill because BG cat_id = num_classes,
            # FG cat_id = [0, num_classes-1]
            labels = pos_bboxes.new_full(
                (total, ), self.num_classes, dtype=torch.long)
            label_weights = pos_bboxes.new_zeros(total)
            bbox_targets = pos_bboxes.new_zeros(total, 4)
            bbox_weights = pos_bboxes.new_zeros(total, 4)

            if num_pos > 0:
                labels[:num_pos] = pos_gt_labels

                # A non-positive configured weight falls back to 1.0.
                if cfg.pos_weight <= 0:
                    pos_weight = 1.0
                else:
                    pos_weight = cfg.pos_weight
                label_weights[:num_pos] = pos_weight

                # Regress either raw ground-truth boxes or encoded deltas.
                if self.reg_decoded_bbox:
                    pos_bbox_targets = pos_gt_bboxes
                else:
                    pos_bbox_targets = self.bbox_coder.encode(
                        pos_bboxes, pos_gt_bboxes)
                bbox_targets[:num_pos, :] = pos_bbox_targets
                bbox_weights[:num_pos, :] = 1

            if num_neg > 0:
                # Negatives count for classification only; their bbox
                # weights stay zero.
                label_weights[num_pos:] = 1.0

            return labels, label_weights, bbox_targets, bbox_weights
    
        def get_targets(self,
                        sampling_results,
                        gt_bboxes,
                        gt_labels,
                        rcnn_train_cfg,
                        concat=True):
            """Compute targets for every sampling result in the batch.

            ``_get_target_single`` is applied to each element of
            ``sampling_results`` through ``multi_apply``.

            Args:
                sampling_results (list): Objects exposing ``pos_bboxes``,
                    ``neg_bboxes``, ``pos_gt_bboxes`` and ``pos_gt_labels``.
                gt_bboxes: Not used by this method.
                gt_labels: Not used by this method.
                rcnn_train_cfg: Forwarded to ``_get_target_single`` as ``cfg``.
                concat (bool): If True, concatenate the per-result tensors
                    along dim 0; otherwise return the per-result lists.

            Returns:
                tuple: ``(labels, label_weights, bbox_targets, bbox_weights)``.
            """

            def collect(attr_name):
                # One entry per sampling result, in the original order.
                return [getattr(res, attr_name) for res in sampling_results]

            labels, label_weights, bbox_targets, bbox_weights = multi_apply(
                self._get_target_single,
                collect('pos_bboxes'),
                collect('neg_bboxes'),
                collect('pos_gt_bboxes'),
                collect('pos_gt_labels'),
                cfg=rcnn_train_cfg)

            if not concat:
                return labels, label_weights, bbox_targets, bbox_weights

            return (torch.cat(labels, 0), torch.cat(label_weights, 0),
                    torch.cat(bbox_targets, 0), torch.cat(bbox_weights, 0))
    
        @force_fp32(apply_to=('cls_score', 'bbox_pred'))
        def loss(self,
                 cls_score,
                 bbox_pred,
                 rois,
                 labels,
                 label_weights,
                 bbox_targets,
                 bbox_weights,
                 reduction_override=None):
            """Compute the classification and box regression losses.

            Args:
                cls_score (Tensor | None): Classification scores; the
                    classification branch is skipped when ``None``.
                bbox_pred (Tensor | None): Box predictions; the regression
                    branch is skipped when ``None``.
                rois (Tensor): RoIs; columns ``1:`` are passed to the box
                    coder when ``reg_decoded_bbox`` is set.
                labels (Tensor): Target labels; values in
                    ``[0, num_classes)`` are treated as foreground.
                label_weights (Tensor): Per-sample classification weights.
                bbox_targets (Tensor): Regression targets.
                bbox_weights (Tensor): Per-sample regression weights.
                reduction_override: Forwarded unchanged to both loss modules.

            Returns:
                dict: May contain ``'loss_cls'`` and ``'acc'`` (only when
                ``cls_score`` is non-empty) and ``'loss_bbox'``.
            """
            losses = dict()

            if cls_score is not None:
                # Number of samples with a positive weight, at least 1.
                avg_factor = max(torch.sum(label_weights > 0).float().item(), 1.)
                if cls_score.numel() > 0:
                    losses['loss_cls'] = self.loss_cls(
                        cls_score,
                        labels,
                        label_weights,
                        avg_factor=avg_factor,
                        reduction_override=reduction_override)
                    losses['acc'] = accuracy(cls_score, labels)

            if bbox_pred is None:
                return losses

            # 0~self.num_classes-1 are FG, self.num_classes is BG
            bg_class_ind = self.num_classes
            fg_mask = ((labels >= 0) & (labels < bg_class_ind)).type(torch.bool)

            # do not perform bounding box regression for BG anymore.
            if not fg_mask.any():
                # Sum over an empty selection: a zero that still depends on
                # ``bbox_pred``.
                losses['loss_bbox'] = bbox_pred[fg_mask].sum()
                return losses

            if self.reg_decoded_bbox:
                bbox_pred = self.bbox_coder.decode(rois[:, 1:], bbox_pred)

            num_samples = bbox_pred.size(0)
            if self.reg_class_agnostic:
                pos_bbox_pred = bbox_pred.view(num_samples, 4)[fg_mask]
            else:
                # Pick, for every foreground sample, the 4 values that
                # belong to its target class.
                per_class_pred = bbox_pred.view(num_samples, -1, 4)
                pos_bbox_pred = per_class_pred[fg_mask, labels[fg_mask]]

            losses['loss_bbox'] = self.loss_bbox(
                pos_bbox_pred,
                bbox_targets[fg_mask],
                bbox_weights[fg_mask],
                avg_factor=bbox_targets.size(0),
                reduction_override=reduction_override)
            return losses
    
        @force_fp32(apply_to=('cls_score', 'bbox_pred'))
        def get_bboxes(self,
                       rois,
                       cls_score,
                       bbox_pred,
                       img_shape,
                       scale_factor,
                       rescale=False,
                       cfg=None):
            """Turn network outputs into boxes and scores, optionally with NMS.

            Args:
                rois (Tensor): RoIs whose first column is the image id and
                    whose remaining columns are the box coordinates
                    (x1, y1, x2, y2).
                cls_score (Tensor | list[Tensor] | None): Classification
                    scores. A list is averaged element-wise first.
                bbox_pred (Tensor | None): Box predictions to decode against
                    the RoIs. When ``None`` the RoIs themselves are used.
                img_shape (tuple | None): Image shape; ``img_shape[0]`` bounds
                    the y coordinates and ``img_shape[1]`` the x coordinates.
                scale_factor (float | sequence): Divisor applied to the boxes
                    when ``rescale`` is True.
                rescale (bool): Divide the boxes by ``scale_factor``.
                cfg: Test config. When ``None`` the raw boxes and scores are
                    returned; otherwise ``cfg.score_thr``, ``cfg.nms`` and
                    ``cfg.max_per_img`` are passed to ``multiclass_nms``.

            Returns:
                tuple: ``(bboxes, scores)`` when ``cfg`` is ``None``, else
                ``(det_bboxes, det_labels)`` from ``multiclass_nms``.
            """
            if isinstance(cls_score, list):
                cls_score = sum(cls_score) / float(len(cls_score))
            scores = F.softmax(cls_score, dim=1) if cls_score is not None else None

            if bbox_pred is not None:
                bboxes = self.bbox_coder.decode(
                    rois[:, 1:], bbox_pred, max_shape=img_shape)
            else:
                bboxes = rois[:, 1:].clone()
                if img_shape is not None:
                    # Strided slices are views, so the in-place clamp really
                    # updates ``bboxes``. Indexing with a list ([0, 2]) would
                    # produce a copy and silently discard the clamping.
                    bboxes[:, 0::2].clamp_(min=0, max=img_shape[1])
                    bboxes[:, 1::2].clamp_(min=0, max=img_shape[0])

            if rescale and bboxes.size(0) > 0:
                if isinstance(scale_factor, float):
                    bboxes /= scale_factor
                else:
                    # Per-coordinate factors: broadcast over groups of 4.
                    scale_factor = bboxes.new_tensor(scale_factor)
                    bboxes = (bboxes.view(bboxes.size(0), -1, 4) /
                              scale_factor).view(bboxes.size()[0], -1)

            if cfg is None:
                return bboxes, scores

            det_bboxes, det_labels = multiclass_nms(bboxes, scores,
                                                    cfg.score_thr, cfg.nms,
                                                    cfg.max_per_img)
            return det_bboxes, det_labels
    
        @force_fp32(apply_to=('bbox_preds', ))
        def refine_bboxes(self, rois, labels, bbox_preds, pos_is_gts, img_metas):
            """Refine bboxes during training.

            Args:
                rois (Tensor): Shape (n*bs, 5), where n is image number per GPU,
                    and bs is the sampled RoIs per image. The first column is
                    the image id and the next 4 columns are x1, y1, x2, y2.
                labels (Tensor): Shape (n*bs, ).
                bbox_preds (Tensor): Shape (n*bs, 4) or (n*bs, 4*#class).
                pos_is_gts (list[Tensor]): Flags indicating if each positive bbox
                    is a gt bbox. Both integer (0/1) and bool tensors are
                    accepted.
                img_metas (list[dict]): Meta info of each image.

            Returns:
                list[Tensor]: Refined bboxes of each image in a mini-batch.

            Example:
                >>> # xdoctest: +REQUIRES(module:kwarray)
                >>> import kwarray
                >>> import numpy as np
                >>> from mmdet.core.bbox.demodata import random_boxes
                >>> self = BBoxHead(reg_class_agnostic=True)
                >>> n_roi = 2
                >>> n_img = 4
                >>> scale = 512
                >>> rng = np.random.RandomState(0)
                >>> img_metas = [{'img_shape': (scale, scale)}
                ...              for _ in range(n_img)]
                >>> # Create rois in the expected format
                >>> roi_boxes = random_boxes(n_roi, scale=scale, rng=rng)
                >>> img_ids = torch.randint(0, n_img, (n_roi,))
                >>> img_ids = img_ids.float()
                >>> rois = torch.cat([img_ids[:, None], roi_boxes], dim=1)
                >>> # Create other args
                >>> labels = torch.randint(0, 2, (n_roi,)).long()
                >>> bbox_preds = random_boxes(n_roi, scale=scale, rng=rng)
                >>> # For each image, pretend random positive boxes are gts
                >>> is_label_pos = (labels.numpy() > 0).astype(int)
                >>> lbl_per_img = kwarray.group_items(is_label_pos,
                ...                                   img_ids.numpy())
                >>> pos_per_img = [sum(lbl_per_img.get(gid, []))
                ...                for gid in range(n_img)]
                >>> pos_is_gts = [
                ...     torch.randint(0, 2, (npos,)).byte().sort(
                ...         descending=True)[0]
                ...     for npos in pos_per_img
                ... ]
                >>> bboxes_list = self.refine_bboxes(rois, labels, bbox_preds,
                ...                    pos_is_gts, img_metas)
                >>> print(bboxes_list)
            """
            img_ids = rois[:, 0].long().unique(sorted=True)
            assert img_ids.numel() <= len(img_metas)

            bboxes_list = []
            for i, img_meta_ in enumerate(img_metas):
                # Rows of the batch that belong to image ``i``.
                inds = torch.nonzero(
                    rois[:, 0] == i, as_tuple=False).squeeze(dim=1)
                num_rois = inds.numel()

                bboxes_ = rois[inds, 1:]
                label_ = labels[inds]
                bbox_pred_ = bbox_preds[inds]
                pos_is_gts_ = pos_is_gts[i]

                bboxes = self.regress_by_class(bboxes_, label_, bbox_pred_,
                                               img_meta_)

                # filter gt bboxes
                # The flags cover only the leading RoIs of the image; every
                # remaining RoI is kept. Logical negation is used instead of
                # ``1 - flags`` because subtraction is not defined for bool
                # tensors.
                keep_inds = pos_is_gts_.new_ones(num_rois, dtype=torch.bool)
                keep_inds[:len(pos_is_gts_)] = ~pos_is_gts_.to(torch.bool)

                bboxes_list.append(bboxes[keep_inds])

            return bboxes_list
    
        @force_fp32(apply_to=('bbox_pred', ))
        def regress_by_class(self, rois, label, bbox_pred, img_meta):
            """Decode, for every RoI, the box predicted for its given class.

            Used in Cascade R-CNN.

            Args:
                rois (Tensor): shape (n, 4) or (n, 5)
                label (Tensor): shape (n, )
                bbox_pred (Tensor): shape (n, 4*(#class)) or (n, 4)
                img_meta (dict): Image meta info; ``img_meta['img_shape']`` is
                    passed to the box coder as ``max_shape``.

            Returns:
                Tensor: Regressed bboxes, the same shape as input rois.
            """
            num_cols = rois.size(1)
            assert num_cols == 4 or num_cols == 5, repr(rois.shape)

            if not self.reg_class_agnostic:
                # Columns [4*label, 4*label + 3] hold the prediction for the
                # class ``label``; gather them into an (n, 4) tensor.
                first_col = label * 4
                class_cols = torch.stack(
                    (first_col, first_col + 1, first_col + 2, first_col + 3),
                    1)
                bbox_pred = torch.gather(bbox_pred, 1, class_cols)
            assert bbox_pred.size(1) == 4

            max_shape = img_meta['img_shape']
            if num_cols == 4:
                return self.bbox_coder.decode(
                    rois, bbox_pred, max_shape=max_shape)

            # Five columns: keep the leading column and replace the box part.
            decoded = self.bbox_coder.decode(
                rois[:, 1:], bbox_pred, max_shape=max_shape)
            return torch.cat((rois[:, [0]], decoded), dim=1)

    主要的函数有:

    (1) __init__():初始化函数,主要参数包括RoI Pooling输出特征的尺寸大小,输入通道数等等;

    (2) get_targets():计算目标值,通过调用_get_target_single()实现;

    (3) _get_target_single():计算单张图像采样结果(正负样本)的目标值,包括labels, label_weights, bbox_targets和bbox_weights;

    (4) get_bboxes():对回归值解码得到预测框,并对分类值做softmax得到得分,cfg不为None时再进行多类别NMS,输出检测框和类别;

    (5) forward():前向传播;

    (6) loss():计算损失函数值.

  • 相关阅读:
    软件项目版本号的命名规则及格式
    你必须知道的C#的25个基础概念
    Visual C#常用函数和方法集汇总
    web标准下的web开发流程思考
    设计模式(5)>模板方法 小强斋
    设计模式(9)>迭代器模式 小强斋
    设计模式(10)>策略模式 小强斋
    设计模式(8)>代理模式 小强斋
    设计模式(7)>观察者模式 小强斋
    设计模式(7)>观察者模式 小强斋
  • 原文地址:https://www.cnblogs.com/mstk/p/15330915.html
Copyright © 2011-2022 走看看