  • MMDetection Source Code Analysis: Faster RCNN (3) -- The RPN Head Class

    In the Faster RCNN configuration file faster_rcnn_r50_fpn.py, the snippet

        backbone=dict(
            type='ResNet',
            depth=50,
            num_stages=4,
            out_indices=(0, 1, 2, 3),
            frozen_stages=1,
            norm_cfg=dict(type='BN', requires_grad=True),
            norm_eval=True,
            style='pytorch'),

    sets the backbone to ResNet.

        neck=dict(
            type='FPN',
            in_channels=[256, 512, 1024, 2048],
            out_channels=256,
            num_outs=5),

    sets the neck to FPN. The backbone and the neck are relatively simple and are not covered in detail here; this post focuses on the RPN Head.
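    Before moving on to the RPN head, it helps to see what the backbone and the neck actually output. The snippet below is a minimal sketch, assuming an MMDetection 2.x installation in which build_backbone and build_neck are available from mmdet.models; the 800x800 input and the resulting feature-map sizes are only illustrative.

        import torch
        from mmdet.models import build_backbone, build_neck

        # Build the backbone and neck from the same dicts as in the config above
        # (weights are randomly initialized here, which is fine for a shape check).
        backbone = build_backbone(dict(
            type='ResNet', depth=50, num_stages=4, out_indices=(0, 1, 2, 3),
            frozen_stages=1, norm_cfg=dict(type='BN', requires_grad=True),
            norm_eval=True, style='pytorch'))
        neck = build_neck(dict(
            type='FPN', in_channels=[256, 512, 1024, 2048], out_channels=256,
            num_outs=5))
        backbone.eval()
        neck.eval()

        x = torch.randn(1, 3, 800, 800)   # dummy image batch
        c_feats = backbone(x)             # 4 maps: 256/512/1024/2048 channels, strides 4/8/16/32
        p_feats = neck(c_feats)           # 5 maps: all 256 channels, strides 4/8/16/32/64
        for p in p_feats:
            print(p.shape)                # (1, 256, 200, 200) down to (1, 256, 13, 13)

    With these multi-level feature maps in place, the RPN head is configured as follows: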

        rpn_head=dict(
            type='RPNHead',
            in_channels=256,
            feat_channels=256,
            anchor_generator=dict(
                type='AnchorGenerator',
                scales=[8],
                ratios=[0.5, 1.0, 2.0],
                strides=[4, 8, 16, 32, 64]),
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[.0, .0, .0, .0],
                target_stds=[1.0, 1.0, 1.0, 1.0]),
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0)),

    The configuration above sets the RPN head to RPNHead (defined in mmdetection/mmdet/models/dense_heads/rpn_head.py). It also defines anchor_generator as AnchorGenerator (defined in mmdetection/mmdet/core/anchor/anchor_generator.py), which specifies how anchors are generated; bbox_coder as DeltaXYWHBBoxCoder (defined in mmdetection/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py), which specifies how bounding boxes are encoded and decoded; and loss_cls and loss_bbox, which set the classification loss and the regression loss of the bounding boxes.
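    To make these settings concrete, the following is a minimal sketch of the anchor generation and box coding, assuming an MMDetection 2.x installation where AnchorGenerator and DeltaXYWHBBoxCoder can be imported from mmdet.core; the feature-map sizes and the example boxes are made up for illustration. With scales=[8], ratios=[0.5, 1.0, 2.0] and strides=[4, 8, 16, 32, 64], every location of every FPN level gets 3 anchors, and the base anchor area at a level is roughly (stride * 8)^2.

        import torch
        from mmdet.core import AnchorGenerator, DeltaXYWHBBoxCoder  # mmdet 2.x import paths

        anchor_generator = AnchorGenerator(
            strides=[4, 8, 16, 32, 64],
            ratios=[0.5, 1.0, 2.0],
            scales=[8])
        print(anchor_generator.num_base_anchors)  # [3, 3, 3, 3, 3]: 1 scale x 3 ratios per level
        print(anchor_generator.base_anchors[0])   # 3 base anchors of area ~(4 * 8)^2 for stride 4

        # Anchors laid out over illustrative feature-map sizes of the 5 FPN levels.
        featmap_sizes = [(200, 200), (100, 100), (50, 50), (25, 25), (13, 13)]
        multi_level_anchors = anchor_generator.grid_anchors(featmap_sizes, device='cpu')
        print([a.shape for a in multi_level_anchors])  # (H * W * 3, 4) per level

        bbox_coder = DeltaXYWHBBoxCoder(
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0])
        anchor = torch.tensor([[0., 0., 32., 32.]])
        gt_box = torch.tensor([[4., 2., 40., 38.]])
        deltas = bbox_coder.encode(anchor, gt_box)   # regression targets (dx, dy, dw, dh)
        decoded = bbox_coder.decode(anchor, deltas)  # recovers gt_box up to floating-point error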

    The contents of rpn_head.py are as follows:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from mmcv.cnn import normal_init
    from mmcv.ops import batched_nms
    
    from ..builder import HEADS
    from .anchor_head import AnchorHead
    from .rpn_test_mixin import RPNTestMixin
    
    
    @HEADS.register_module()
    class RPNHead(RPNTestMixin, AnchorHead):
        """RPN head.
    
        Args:
            in_channels (int): Number of channels in the input feature map.
        """  # noqa: W605
    
        def __init__(self, in_channels, **kwargs):
            super(RPNHead, self).__init__(1, in_channels, **kwargs)
    
        def _init_layers(self):
            """Initialize layers of the head."""
            self.rpn_conv = nn.Conv2d(
                self.in_channels, self.feat_channels, 3, padding=1)
            self.rpn_cls = nn.Conv2d(self.feat_channels,
                                     self.num_anchors * self.cls_out_channels, 1)
            self.rpn_reg = nn.Conv2d(self.feat_channels, self.num_anchors * 4, 1)
    
        def init_weights(self):
            """Initialize weights of the head."""
            normal_init(self.rpn_conv, std=0.01)
            normal_init(self.rpn_cls, std=0.01)
            normal_init(self.rpn_reg, std=0.01)
    
        def forward_single(self, x):
            """Forward feature map of a single scale level."""
            x = self.rpn_conv(x)
            x = F.relu(x, inplace=True)
            rpn_cls_score = self.rpn_cls(x)
            rpn_bbox_pred = self.rpn_reg(x)
            return rpn_cls_score, rpn_bbox_pred
    
        def loss(self,
                 cls_scores,
                 bbox_preds,
                 gt_bboxes,
                 img_metas,
                 gt_bboxes_ignore=None):
            """Compute losses of the head.
    
            Args:
                cls_scores (list[Tensor]): Box scores for each scale level
                    Has shape (N, num_anchors * num_classes, H, W)
                bbox_preds (list[Tensor]): Box energies / deltas for each scale
                    level with shape (N, num_anchors * 4, H, W)
                gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                    shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
                img_metas (list[dict]): Meta information of each image, e.g.,
                    image size, scaling factor, etc.
                gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                    boxes can be ignored when computing the loss.
    
            Returns:
                dict[str, Tensor]: A dictionary of loss components.
            """
            losses = super(RPNHead, self).loss(
                cls_scores,
                bbox_preds,
                gt_bboxes,
                None,
                img_metas,
                gt_bboxes_ignore=gt_bboxes_ignore)
            return dict(
                loss_rpn_cls=losses['loss_cls'], loss_rpn_bbox=losses['loss_bbox'])
    
        def _get_bboxes_single(self,
                               cls_scores,
                               bbox_preds,
                               mlvl_anchors,
                               img_shape,
                               scale_factor,
                               cfg,
                               rescale=False):
            """Transform outputs for a single batch item into bbox predictions.
    
            Args:
                cls_scores (list[Tensor]): Box scores for each scale level
                    Has shape (num_anchors * num_classes, H, W).
                bbox_preds (list[Tensor]): Box energies / deltas for each scale
                    level with shape (num_anchors * 4, H, W).
                mlvl_anchors (list[Tensor]): Box reference for each scale level
                    with shape (num_total_anchors, 4).
                img_shape (tuple[int]): Shape of the input image,
                    (height, width, 3).
                scale_factor (ndarray): Scale factor of the image arange as
                    (w_scale, h_scale, w_scale, h_scale).
                cfg (mmcv.Config): Test / postprocessing configuration,
                    if None, test_cfg would be used.
                rescale (bool): If True, return boxes in original image space.
    
            Returns:
                Tensor: Labeled boxes in shape (n, 5), where the first 4 columns
                    are bounding box positions (tl_x, tl_y, br_x, br_y) and the
                    5-th column is a score between 0 and 1.
            """
            cfg = self.test_cfg if cfg is None else cfg
            # bboxes from different level should be independent during NMS,
            # level_ids are used as labels for batched NMS to separate them
            level_ids = []
            mlvl_scores = []
            mlvl_bbox_preds = []
            mlvl_valid_anchors = []
            for idx in range(len(cls_scores)):
                rpn_cls_score = cls_scores[idx]
                rpn_bbox_pred = bbox_preds[idx]
                assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
                rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
                if self.use_sigmoid_cls:
                    rpn_cls_score = rpn_cls_score.reshape(-1)
                    scores = rpn_cls_score.sigmoid()
                else:
                    rpn_cls_score = rpn_cls_score.reshape(-1, 2)
                    # We set FG labels to [0, num_class-1] and BG label to
                    # num_class in RPN head since mmdet v2.5, which is unified to
                    # be consistent with other head since mmdet v2.0. In mmdet v2.0
                    # to v2.4 we keep BG label as 0 and FG label as 1 in rpn head.
                    scores = rpn_cls_score.softmax(dim=1)[:, 0]
                rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4)
                anchors = mlvl_anchors[idx]
                if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
                    # sort is faster than topk
                    # _, topk_inds = scores.topk(cfg.nms_pre)
                    ranked_scores, rank_inds = scores.sort(descending=True)
                    topk_inds = rank_inds[:cfg.nms_pre]
                    scores = ranked_scores[:cfg.nms_pre]
                    rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
                    anchors = anchors[topk_inds, :]
                mlvl_scores.append(scores)
                mlvl_bbox_preds.append(rpn_bbox_pred)
                mlvl_valid_anchors.append(anchors)
                level_ids.append(
                    scores.new_full((scores.size(0), ), idx, dtype=torch.long))
    
            scores = torch.cat(mlvl_scores)
            anchors = torch.cat(mlvl_valid_anchors)
            rpn_bbox_pred = torch.cat(mlvl_bbox_preds)
            proposals = self.bbox_coder.decode(
                anchors, rpn_bbox_pred, max_shape=img_shape)
            ids = torch.cat(level_ids)
    
            if cfg.min_bbox_size > 0:
                w = proposals[:, 2] - proposals[:, 0]
                h = proposals[:, 3] - proposals[:, 1]
                valid_inds = torch.nonzero(
                    (w >= cfg.min_bbox_size)
                    & (h >= cfg.min_bbox_size),
                    as_tuple=False).squeeze()
                if valid_inds.sum().item() != len(proposals):
                    proposals = proposals[valid_inds, :]
                    scores = scores[valid_inds]
                    ids = ids[valid_inds]
    
            # TODO: remove the hard coded nms type
            nms_cfg = dict(type='nms', iou_threshold=cfg.nms_thr)
            dets, keep = batched_nms(proposals, scores, ids, nms_cfg)
            return dets[:cfg.nms_post]

    RPNHead mainly consists of the following functions:

    (1) __init__(), the initialization function.

    (2) _init_layers(), which builds the layers of the head. As the code shows, there are three layers: the convolution layer rpn_conv, the classification layer rpn_cls and the regression layer rpn_reg.

    All three layers use the nn.Conv2d class (defined in torch/nn/modules/conv.py), whose first argument is in_channels, second argument is out_channels and third argument is kernel_size. The out_channels of rpn_cls is the number of anchors per location multiplied by cls_out_channels (num_classes when sigmoid classification is used, num_classes + 1 otherwise; here it is 1, because the RPN only separates foreground from background and use_sigmoid=True). The out_channels of rpn_reg is the number of anchors multiplied by 4, because each bounding box is determined by 4 outputs. A standalone sketch after this list makes these shapes concrete.

    (3) init_weights(), which initializes the weights of the head.

    (4) forward_single(), which runs the forward pass on a single FPN feature map, i.e. outputs the classification scores and bbox predictions of that level.

    (5) loss(), which computes the loss.

    (6) _get_bboxes_single(), which takes the outputs of forward_single() and the anchors of each level, obtains the final predictions and then performs NMS.
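    To make the channel arithmetic in (2) and the per-level forward pass in (4) concrete, here is a small standalone PyTorch sketch that mirrors _init_layers() and forward_single() with the values implied by the configuration above (3 anchors per location and sigmoid classification, so cls_out_channels is 1); the feature-map size is only illustrative.

        import torch
        import torch.nn as nn
        import torch.nn.functional as F

        # Values implied by the config above: 1 scale x 3 ratios = 3 anchors per location,
        # and cls_out_channels = 1 because the RPN classifies objectness with a sigmoid.
        in_channels, feat_channels, num_anchors, cls_out_channels = 256, 256, 3, 1

        rpn_conv = nn.Conv2d(in_channels, feat_channels, 3, padding=1)
        rpn_cls = nn.Conv2d(feat_channels, num_anchors * cls_out_channels, 1)
        rpn_reg = nn.Conv2d(feat_channels, num_anchors * 4, 1)

        x = torch.randn(1, 256, 50, 50)    # one FPN level, e.g. stride 16 on an 800x800 image
        feat = F.relu(rpn_conv(x))
        rpn_cls_score = rpn_cls(feat)      # (1, 3, 50, 50): one objectness logit per anchor
        rpn_bbox_pred = rpn_reg(feat)      # (1, 12, 50, 50): 4 box deltas per anchor
        print(rpn_cls_score.shape, rpn_bbox_pred.shape)

    From these per-level outputs, _get_bboxes_single() keeps the top-scoring predictions, decodes them against their anchors with the bbox coder, filters out tiny boxes, and merges the levels with batched NMS, as shown in the listing above.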

  • Original article: https://www.cnblogs.com/mstk/p/14658987.html