zoukankan      html  css  js  c++  java
  • MMDetection源码解析:Faster RCNN(7)--ConvFCBBoxHead,Shared2FCBBoxHead和Shared4Conv1FCBBoxHead类

    ConvFCBBoxHead类定义在mmdet/models/roi_heads/bbox_heads/convfc_bbox_head.py中,它继承自BBoxHead类.其作用是对输入特征依次进行共享的卷积和全连接操作,再分别经过(可选的)分类分支和回归分支,最后送入根据新的输入维度重建的fc_cls和fc_reg层,得到分类得分和边框回归结果.convfc_bbox_head.py还包含了Shared2FCBBoxHead和Shared4Conv1FCBBoxHead两个类.

    import torch.nn as nn
    from mmcv.cnn import ConvModule
    
    from mmdet.models.builder import HEADS
    from .bbox_head import BBoxHead
    
    
    @HEADS.register_module()
    class ConvFCBBoxHead(BBoxHead):
        r"""More general bbox head, with shared conv and fc layers and two optional
        separated branches.

        .. code-block:: none

                                        /-> cls convs -> cls fcs -> cls
            shared convs -> shared fcs
                                        \-> reg convs -> reg fcs -> reg

        Args:
            num_shared_convs (int): Number of 3x3 conv layers shared by the
                classification and regression branches.
            num_shared_fcs (int): Number of fc layers shared by both branches.
                Must be 0 when either branch has its own conv layers.
            num_cls_convs (int): Number of conv layers in the cls branch only.
            num_cls_fcs (int): Number of fc layers in the cls branch only.
            num_reg_convs (int): Number of conv layers in the reg branch only.
            num_reg_fcs (int): Number of fc layers in the reg branch only.
            conv_out_channels (int): Output channels of every conv layer.
            fc_out_channels (int): Output width of every fc layer.
            conv_cfg (dict | None): Passed unchanged to each ``ConvModule``.
            norm_cfg (dict | None): Passed unchanged to each ``ConvModule``.
            *args, **kwargs: Forwarded to ``BBoxHead.__init__``.
        """  # noqa: W605

        def __init__(self,
                     num_shared_convs=0,
                     num_shared_fcs=0,
                     num_cls_convs=0,
                     num_cls_fcs=0,
                     num_reg_convs=0,
                     num_reg_fcs=0,
                     conv_out_channels=256,
                     fc_out_channels=1024,
                     conv_cfg=None,
                     norm_cfg=None,
                     *args,
                     **kwargs):
            super(ConvFCBBoxHead, self).__init__(*args, **kwargs)
            # Validate the layer layout. These remain ``assert`` statements so
            # the exception type seen by existing callers does not change.
            assert (num_shared_convs + num_shared_fcs + num_cls_convs +
                    num_cls_fcs + num_reg_convs + num_reg_fcs > 0), (
                        'ConvFCBBoxHead needs at least one conv or fc layer')
            if num_cls_convs > 0 or num_reg_convs > 0:
                # Branch convs need a spatial feature map, but the shared fcs
                # (see forward) flatten it away.
                assert num_shared_fcs == 0, (
                    'branch convs cannot follow shared fc layers')
            if not self.with_cls:
                assert num_cls_convs == 0 and num_cls_fcs == 0, (
                    'cls branch layers were given but with_cls is False')
            if not self.with_reg:
                assert num_reg_convs == 0 and num_reg_fcs == 0, (
                    'reg branch layers were given but with_reg is False')
            self.num_shared_convs = num_shared_convs
            self.num_shared_fcs = num_shared_fcs
            self.num_cls_convs = num_cls_convs
            self.num_cls_fcs = num_cls_fcs
            self.num_reg_convs = num_reg_convs
            self.num_reg_fcs = num_reg_fcs
            self.conv_out_channels = conv_out_channels
            self.fc_out_channels = fc_out_channels
            self.conv_cfg = conv_cfg
            self.norm_cfg = norm_cfg

            # Add shared convs and fcs. The multi-line assignments use
            # parentheses rather than backslash continuations.
            (self.shared_convs, self.shared_fcs,
             last_layer_dim) = self._add_conv_fc_branch(
                 self.num_shared_convs, self.num_shared_fcs, self.in_channels,
                 True)
            self.shared_out_channels = last_layer_dim

            # add cls specific branch
            (self.cls_convs, self.cls_fcs,
             self.cls_last_dim) = self._add_conv_fc_branch(
                 self.num_cls_convs, self.num_cls_fcs, self.shared_out_channels)

            # add reg specific branch
            (self.reg_convs, self.reg_fcs,
             self.reg_last_dim) = self._add_conv_fc_branch(
                 self.num_reg_convs, self.num_reg_fcs, self.shared_out_channels)

            # With no shared fcs, no avg pool and no branch fcs, forward feeds
            # the flattened feature map straight into fc_cls / fc_reg, so their
            # input size is channels * roi_feat_area. roi_feat_area comes from
            # BBoxHead and is presumably the RoI's H * W; confirm there.
            if self.num_shared_fcs == 0 and not self.with_avg_pool:
                if self.num_cls_fcs == 0:
                    self.cls_last_dim *= self.roi_feat_area
                if self.num_reg_fcs == 0:
                    self.reg_last_dim *= self.roi_feat_area

            self.relu = nn.ReLU(inplace=False)
            # reconstruct fc_cls and fc_reg since input channels are changed
            if self.with_cls:
                # The "+ 1" presumably follows BBoxHead's extra background
                # logit convention (verify in bbox_head.py).
                self.fc_cls = nn.Linear(self.cls_last_dim, self.num_classes + 1)
            if self.with_reg:
                # One box (4 deltas) when class-agnostic, else one per class.
                out_dim_reg = (4 if self.reg_class_agnostic else 4 *
                               self.num_classes)
                self.fc_reg = nn.Linear(self.reg_last_dim, out_dim_reg)

        def _add_conv_fc_branch(self,
                                num_branch_convs,
                                num_branch_fcs,
                                in_channels,
                                is_shared=False):
            """Add shared or separable branch.

            convs -> avg pool (optional) -> fcs

            Args:
                num_branch_convs (int): Number of 3x3 ``ConvModule`` layers.
                num_branch_fcs (int): Number of ``nn.Linear`` layers.
                in_channels (int): Channels of the feature fed to this branch.
                is_shared (bool): True for the shared branch. It then decides
                    whether its first fc sees a flattened map based only on
                    ``with_avg_pool``.

            Returns:
                tuple: ``(branch_convs, branch_fcs, last_layer_dim)``. The
                first two are ``nn.ModuleList`` objects (possibly empty).
                ``last_layer_dim`` is the output size of the last layer added,
                or ``in_channels`` if no layer was added.
            """
            last_layer_dim = in_channels
            # add branch specific conv layers
            branch_convs = nn.ModuleList()
            if num_branch_convs > 0:
                for i in range(num_branch_convs):
                    conv_in_channels = (
                        last_layer_dim if i == 0 else self.conv_out_channels)
                    branch_convs.append(
                        ConvModule(
                            conv_in_channels,
                            self.conv_out_channels,
                            3,
                            padding=1,
                            conv_cfg=self.conv_cfg,
                            norm_cfg=self.norm_cfg))
                last_layer_dim = self.conv_out_channels
            # add branch specific fc layers
            branch_fcs = nn.ModuleList()
            if num_branch_fcs > 0:
                # for shared branch, only consider self.with_avg_pool
                # for separated branches, also consider self.num_shared_fcs
                # (shared fcs would already have produced a flat vector).
                if (is_shared
                        or self.num_shared_fcs == 0) and not self.with_avg_pool:
                    # The first fc consumes a flattened map, not a pooled one.
                    last_layer_dim *= self.roi_feat_area
                for i in range(num_branch_fcs):
                    fc_in_channels = (
                        last_layer_dim if i == 0 else self.fc_out_channels)
                    branch_fcs.append(
                        nn.Linear(fc_in_channels, self.fc_out_channels))
                last_layer_dim = self.fc_out_channels
            return branch_convs, branch_fcs, last_layer_dim

        def init_weights(self):
            """Initialise weights.

            Runs ``BBoxHead.init_weights`` first. Then every ``nn.Linear`` in
            the shared, cls and reg fc branches gets Xavier-uniform weights
            and zero bias.
            """
            super(ConvFCBBoxHead, self).init_weights()
            # conv layers are already initialized by ConvModule
            for module_list in [self.shared_fcs, self.cls_fcs, self.reg_fcs]:
                for m in module_list.modules():
                    if isinstance(m, nn.Linear):
                        nn.init.xavier_uniform_(m.weight)
                        nn.init.constant_(m.bias, 0)

        def forward(self, x):
            """Run the shared layers, then the cls and reg branches.

            Args:
                x (Tensor): Input feature map. Presumably pooled RoI features
                    of shape (N, C, H, W); TODO confirm against the RoI
                    extractor.

            Returns:
                tuple: ``(cls_score, bbox_pred)``.

                - ``cls_score`` is the ``fc_cls`` output
                  (``num_classes + 1`` per sample), or None when
                  ``with_cls`` is False.
                - ``bbox_pred`` is the ``fc_reg`` output
                  (4 or ``4 * num_classes`` per sample), or None when
                  ``with_reg`` is False.
            """
            # Shared convs. When there are none, the ModuleList is empty and
            # this loop does nothing.
            for conv in self.shared_convs:
                x = conv(x)

            if self.num_shared_fcs > 0:
                if self.with_avg_pool:
                    x = self.avg_pool(x)
                # Collapse every non-batch dimension into one feature vector.
                x = x.flatten(1)
                for fc in self.shared_fcs:
                    x = self.relu(fc(x))

            # separate branches
            x_cls = x
            x_reg = x

            for conv in self.cls_convs:
                x_cls = conv(x_cls)
            # Still a feature map (no shared fcs ran): pool/flatten before fcs.
            if x_cls.dim() > 2:
                if self.with_avg_pool:
                    x_cls = self.avg_pool(x_cls)
                x_cls = x_cls.flatten(1)
            for fc in self.cls_fcs:
                x_cls = self.relu(fc(x_cls))

            for conv in self.reg_convs:
                x_reg = conv(x_reg)
            if x_reg.dim() > 2:
                if self.with_avg_pool:
                    x_reg = self.avg_pool(x_reg)
                x_reg = x_reg.flatten(1)
            for fc in self.reg_fcs:
                x_reg = self.relu(fc(x_reg))

            cls_score = self.fc_cls(x_cls) if self.with_cls else None
            bbox_pred = self.fc_reg(x_reg) if self.with_reg else None
            return cls_score, bbox_pred
    
    
    @HEADS.register_module()
    class Shared2FCBBoxHead(ConvFCBBoxHead):
        """ConvFCBBoxHead preset: two shared fc layers, no convs, and no
        branch-specific layers.

        Args:
            fc_out_channels (int): Output width of each shared fc layer.
            *args, **kwargs: Forwarded to ``ConvFCBBoxHead.__init__``.
        """

        def __init__(self, fc_out_channels=1024, *args, **kwargs):
            # Fixed layer layout for this preset.
            layout = dict(
                num_shared_convs=0,
                num_shared_fcs=2,
                num_cls_convs=0,
                num_cls_fcs=0,
                num_reg_convs=0,
                num_reg_fcs=0)
            super().__init__(
                *args,
                fc_out_channels=fc_out_channels,
                **layout,
                **kwargs)
    
    
    @HEADS.register_module()
    class Shared4Conv1FCBBoxHead(ConvFCBBoxHead):
        """ConvFCBBoxHead preset: four shared conv layers followed by one
        shared fc layer, and no branch-specific layers.

        Args:
            fc_out_channels (int): Output width of the shared fc layer.
            *args, **kwargs: Forwarded to ``ConvFCBBoxHead.__init__``.
        """

        def __init__(self, fc_out_channels=1024, *args, **kwargs):
            # Fixed layer layout for this preset.
            layout = dict(
                num_shared_convs=4,
                num_shared_fcs=1,
                num_cls_convs=0,
                num_cls_fcs=0,
                num_reg_convs=0,
                num_reg_fcs=0)
            super().__init__(
                *args,
                fc_out_channels=fc_out_channels,
                **layout,
                **kwargs)

    主要的函数有:

    (1) __init__():初始化函数,主要参数是各层的数量;

    (2) _add_conv_fc_branch():增加卷积或全连接层;

    (3) init_weights():初始化权重;

    (4) forward():前向传播;

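    Shared2FCBBoxHead和Shared4Conv1FCBBoxHead类继承自ConvFCBBoxHead类,只是固定了下列参数的取值:Shared2FCBBoxHead使用2个共享全连接层,Shared4Conv1FCBBoxHead使用4个共享卷积层加1个共享全连接层,两者都没有分类/回归分支专用的层.ConvFCBBoxHead的主要参数如下: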

    (1) num_shared_convs:共享卷积层数量;

    (2) num_shared_fcs:共享全连接层数量;

    (3) num_cls_convs:分类卷积层数量;

    (4) num_cls_fcs:分类全连接层数量;

    (5) num_reg_convs:回归卷积层的数量;

    (6) num_reg_fcs:回归全连接层的数量;

    (7) fc_out_channels:每个全连接层的输出维度,默认值为1024;

    (8) conv_out_channels:每个卷积层的输出通道数,默认值为256.

    更改这些参数的值,就可以构建不同结构的模型,还是非常方便的.

  • 相关阅读:
    axios实现跨域及突破host和referer的限制
    视频测试URL地址
    微信小程序 自定义导航栏 自动获取高度 写法
    解决flex布局宽度超出时,子元素被压缩的问题
    子组件向父组件通信与父组件向子组件通信
    时间截止器
    arguments
    改变this指向&闭包特性
    ES6扩展——箭头函数
    ES6扩展——函数扩展之剩余函数
  • 原文地址:https://www.cnblogs.com/mstk/p/15120222.html
Copyright © 2011-2022 走看看