  • Receptive Field Block Net for Accurate and Fast Object Detection

    I. Paper Overview

    The paper targets object detection and aims to enlarge the receptive field.

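    The main contribution is as follows (similar ideas may have been proposed before):

    1. A module, RFB, designed to enlarge the receptive field.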

    II. Module Details

    2.1 Overview of the Approach

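    The figure below illustrates the idea: the convolutions mimic the human retina, applying different convolutions to objects at different distances (the deeper the network, the larger the receptive field).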
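
    The parenthetical claim can be checked with the standard receptive-field recurrence for a stack of convolutions. This is a minimal sketch in plain Python; the function name `receptive_field` and the example layer lists are illustrative assumptions, not taken from the paper.

```python
def receptive_field(layers):
    """Receptive field of a stack of conv layers.

    Each layer is a (kernel_size, stride, dilation) tuple. Uses the
    standard recurrence: rf += (k_eff - 1) * jump; jump *= stride,
    where k_eff = dilation * (kernel_size - 1) + 1.
    """
    rf, jump = 1, 1
    for kernel_size, stride, dilation in layers:
        k_eff = dilation * (kernel_size - 1) + 1
        rf += (k_eff - 1) * jump
        jump *= stride
    return rf


# Deeper stacks of plain 3x3 convs see more of the input.
shallow = [(3, 1, 1)] * 2
deep = [(3, 1, 1)] * 4
# A dilated 3x3 conv enlarges the field without adding layers.
dilated = [(3, 1, 1), (3, 1, 3)]

print(receptive_field(shallow))  # 5
print(receptive_field(deep))     # 9
print(receptive_field(dilated))  # 9
```

    Two layers with one dilated convolution reach the same receptive field as four plain layers, which is the effect the RFB branches rely on.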

    The module designed by the authors is shown in the figure below; the figure speaks for itself.

    Compared with other modules that enlarge the receptive field, it closely resembles the ASPP module.
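
    To make the resemblance concrete, here is a minimal sketch of an ASPP-style block: parallel 3x3 convolutions with different dilation rates, concatenated and fused by a 1x1 convolution. The class name `TinyASPP` and the rates (1, 2, 3) are assumptions chosen for illustration; they are not the configuration of any published ASPP variant, nor of the RFB module itself.

```python
import torch
import torch.nn as nn


class TinyASPP(nn.Module):
    """Illustrative ASPP-style block (hypothetical, not from the paper)."""

    def __init__(self, in_channels, branch_channels, rates=(1, 2, 3)):
        super().__init__()
        # For a 3x3 kernel, padding == dilation keeps the spatial size.
        self.branches = nn.ModuleList([
            nn.Conv2d(in_channels, branch_channels, kernel_size=3,
                      padding=r, dilation=r)
            for r in rates
        ])
        # 1x1 conv fuses the concatenated branch outputs.
        self.fuse = nn.Conv2d(branch_channels * len(rates), in_channels,
                              kernel_size=1)

    def forward(self, x):
        out = torch.cat([branch(x) for branch in self.branches], dim=1)
        return self.fuse(out)


x = torch.randn(1, 16, 20, 20)
y = TinyASPP(16, 8)(x)
print(y.shape)  # torch.Size([1, 16, 20, 20])
```

    The BasicRFB code in section 2.2.1 follows the same parallel-branch pattern, but places ordinary convolutions before each dilated one and adds a scaled shortcut connection.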

    Two variants are given: RFB uses larger convolutions (for deeper layers), while RFBs uses smaller convolutions (for shallower layers).

    In practice it does improve object detection, at the cost of some additional computation.
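
    How much is added depends on the channel width. As a rough sketch, the following counts the convolution weights in one BasicRFB block with the layer layout listed in section 2.2.1, assuming groups=1 and ignoring BatchNorm parameters. The helper names `conv_weights` and `basic_rfb_weights` are illustrative.

```python
def conv_weights(c_in, c_out, kernel_size):
    """Weights of a bias-free square conv with groups=1."""
    return c_in * c_out * kernel_size * kernel_size


def basic_rfb_weights(in_planes, out_planes, map_reduce=8):
    """Conv weights in one BasicRFB block (BatchNorm ignored, groups=1)."""
    inter = in_planes // map_reduce
    mid = (inter // 2) * 3
    branch0 = (conv_weights(in_planes, inter, 1)
               + conv_weights(inter, 2 * inter, 3)
               + conv_weights(2 * inter, 2 * inter, 3))
    branch1 = branch0  # same layer shapes, only the dilation differs
    branch2 = (conv_weights(in_planes, inter, 1)
               + conv_weights(inter, mid, 3)
               + conv_weights(mid, 2 * inter, 3)
               + conv_weights(2 * inter, 2 * inter, 3))
    conv_linear = conv_weights(6 * inter, out_planes, 1)
    shortcut = conv_weights(in_planes, out_planes, 1)
    return branch0 + branch1 + branch2 + conv_linear + shortcut


print(basic_rfb_weights(64, 64))  # 20512
```

    Dilation changes where the kernel samples, not how many weights it has, so the dilated layers cost the same as ordinary 3x3 convolutions of the same width.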


    2.2 Implementation

    2.2.1 Code

    import torch
    import torch.nn as nn


    class BasicConv(nn.Module):
        """Conv2d optionally followed by BatchNorm2d and ReLU."""

        def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True):
            super(BasicConv, self).__init__()
            self.out_channels = out_planes
            # The conv bias is redundant when BatchNorm follows, so it is enabled only without BN.
            self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=not bn)
            self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
            self.relu = nn.ReLU(inplace=True) if relu else None
    
        def forward(self, x):
            x = self.conv(x)
            if self.bn is not None:
                x = self.bn(x)
            if self.relu is not None:
                x = self.relu(x)
            return x
    
    
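    class BasicRFB(nn.Module):
        """Receptive Field Block: three parallel branches with increasing dilation,
        fused by a 1x1 conv and added to a 1x1 shortcut."""
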
        def __init__(self, in_planes, out_planes, stride=1, scale=0.1, map_reduce=8, vision=1, groups=1):
            super(BasicRFB, self).__init__()
            self.scale = scale
            self.out_channels = out_planes
            # Bottleneck width shared by all branches; in_planes should be at least map_reduce.
            inter_planes = in_planes // map_reduce
    
            # Branch 0: 1x1 -> 3x3 -> 3x3 with dilation vision + 1.
            self.branch0 = nn.Sequential(
                BasicConv(in_planes, inter_planes, kernel_size=1, stride=1, groups=groups, relu=False),
                BasicConv(inter_planes, 2 * inter_planes, kernel_size=(3, 3), stride=stride, padding=(1, 1), groups=groups),
                BasicConv(2 * inter_planes, 2 * inter_planes, kernel_size=3, stride=1, padding=vision + 1, dilation=vision + 1, relu=False, groups=groups)
            )
            # Branch 1: same layout with dilation vision + 2.
            self.branch1 = nn.Sequential(
                BasicConv(in_planes, inter_planes, kernel_size=1, stride=1, groups=groups, relu=False),
                BasicConv(inter_planes, 2 * inter_planes, kernel_size=(3, 3), stride=stride, padding=(1, 1), groups=groups),
                BasicConv(2 * inter_planes, 2 * inter_planes, kernel_size=3, stride=1, padding=vision + 2, dilation=vision + 2, relu=False, groups=groups)
            )
            # Branch 2: an extra 3x3 conv, then dilation vision + 4 (largest receptive field).
            self.branch2 = nn.Sequential(
                BasicConv(in_planes, inter_planes, kernel_size=1, stride=1, groups=groups, relu=False),
                BasicConv(inter_planes, (inter_planes // 2) * 3, kernel_size=3, stride=1, padding=1, groups=groups),
                BasicConv((inter_planes // 2) * 3, 2 * inter_planes, kernel_size=3, stride=stride, padding=1, groups=groups),
                BasicConv(2 * inter_planes, 2 * inter_planes, kernel_size=3, stride=1, padding=vision + 4, dilation=vision + 4, relu=False, groups=groups)
            )
    
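            # 1x1 conv that fuses the three concatenated branches (each has 2 * inter_planes channels).
            self.ConvLinear = BasicConv(6 * inter_planes, out_planes, kernel_size=1, stride=1, relu=False)
            # 1x1 shortcut that matches the output channels and stride.
            self.shortcut = BasicConv(in_planes, out_planes, kernel_size=1, stride=stride, relu=False)
            self.relu = nn.ReLU(inplace=False)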
    
        def forward(self, x):
            x0 = self.branch0(x)
            x1 = self.branch1(x)
            x2 = self.branch2(x)
    
            # Concatenate along channels, fuse, then add the scaled result to the shortcut.
            out = torch.cat((x0, x1, x2), 1)
            out = self.ConvLinear(out)
            short = self.shortcut(x)
            out = out * self.scale + short
            out = self.relu(out)
    
            return out
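
    The `torch.cat` call in `forward` needs all three branches to produce the same spatial size, and the addition needs the shortcut to match as well. The sketch below checks this in plain Python using the output-size formula documented for `torch.nn.Conv2d`, with the kernel, stride, padding, and dilation values from `BasicRFB` above and the default `vision=1`. The helper names `conv_out_size` and `branch_out_size` are illustrative assumptions.

```python
def conv_out_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of a convolution along one spatial dimension."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def branch_out_size(size, stride, dilation, extra_3x3=False):
    """Spatial size after one BasicRFB branch."""
    size = conv_out_size(size, 1)                       # 1x1 reduction
    if extra_3x3:                                       # branch2 only
        size = conv_out_size(size, 3, padding=1)
    size = conv_out_size(size, 3, stride=stride, padding=1)
    # padding == dilation keeps the size for the dilated 3x3 conv
    return conv_out_size(size, 3, padding=dilation, dilation=dilation)


vision = 1
for stride in (1, 2):
    sizes = [
        branch_out_size(40, stride, vision + 1),
        branch_out_size(40, stride, vision + 2),
        branch_out_size(40, stride, vision + 4, extra_3x3=True),
        conv_out_size(40, 1, stride=stride),            # shortcut
    ]
    print(stride, sizes)
# 1 [40, 40, 40, 40]
# 2 [20, 20, 20, 20]
```

    Because every dilated convolution sets its padding equal to its dilation, only the strided 3x3 layer and the strided shortcut change the resolution, and they do so identically.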
    

    III. References

    • Original paper: Receptive Field Block Net for Accurate and Fast Object Detection
  • Original article: https://www.cnblogs.com/wjy-lulu/p/13822302.html