zoukankan      html  css  js  c++  java
  • SSD源码解读——网络搭建

    之前,对SSD的论文进行了解读,可以回顾之前的博客:https://www.cnblogs.com/dengshunge/p/11665929.html

    为了加深对SSD的理解,因此对SSD的源码进行了复现,主要参考的github项目是ssd.pytorch。同时,我自己对该项目增加了大量注释:https://github.com/Dengshunge/mySSD_pytorch

    搭建SSD的项目,可以分成以下三个部分:

    1. 数据读取
    2. 网络搭建;
    3. 损失函数的构建
    4. 网络测试

    接下来,本篇博客重点分析网络搭建


    该部分整体比较简单,思路也很清晰。

    首先,在train.py中,网络搭建的函数入口是函数build_ssd(),该函数需要传入以下几个参数:"train"或者"test"字符串、图片尺寸、类别数。其中,"train"或者"test"字符串用于区分该网络是用于训练还是测试,这两个阶段的网络有些许不同,本文主要将训练阶段的网络;而类别数需要加上背景,对于VOC而言,有20个类别,加上1个背景,即类别数是21。

    ssd_net = build_ssd('train', voc['min_dim'], voc['num_classes'])

    这里,先放一张SSD的网络结构图,可以看出,SSD网络是有3部分组成的,vgg主干网络,新增网络(Conv6之后的层)和用于检测的头部网络(Extra Feature Layers)。

    接着,在ssd.py中,首先定了一个参数,如下所示。这里主要以SSD300为例。这些参数有什么用呢?字典base的参数指的是用于搭建VGG主干网络输出通道数,其中“M”表示需要进行maxpooling;字典extras的参数同样表示新增层的输出通道数,其中“S”表示需要stride=2的降采样;字典mbo的参数表示用于特征融合的层中,每个层对应未知(x,y)的锚点框数量,在SSD300中,使用了6个层进行特征融合,如Conv_4层中,每个位置使用4个锚点框进行预测。

    base = {
        '300': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M',
                512, 512, 512],  # M表示maxpolling
        '512': [],
    }
    extras = {
        '300': [256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256],  # S表示stride=2
        '512': [],
    }
    mbox = {
        '300': [4, 6, 6, 6, 4, 4],  # 每个特征图的每个点,对应锚点框的数量
        '512': [],
    }

    当定义完需要使用到的参数后,可以进行如具体搭建的环节。函数build_ssd()的定义如下所示。利用函数multibox()来构建SSD网络的各个部分,分别是VGG主干网络,新增层和用于检测的头部网络(或许可以理解为分类头和回归头)。而VGG主干网络是通过函数vgg()来实现,新增层是通过函数add_extras()来实现,而函数multibox()则搭建用于检测的头部网络。最后用这些层来初始化类SSD。

    def build_ssd(phase, size=300, num_class=21):
        if phase != 'test' and phase != 'train':
            raise ("ERROR: Phase: " + phase + " not recognized")
        base_, extras_, head_ = multibox(vgg(base[str(size)]),
                                         add_extras(extras[str(size)], in_channels=1024),
                                         mbox[str(size)],
                                         num_class)
        return SSD(phase, size, base_, extras_, head_, num_class)

    我们来看一下VGG主干网络是如何搭建的。函数vgg()需要将上述的base字典传入进去,根据base字典,来搭建卷积层和池化层。作者对vgg网络进行了改进,即将fc6和fc7更改成conv6和conv7。值得留意的是,在conv6中,使用了空点卷积,dilation=6,增大感受野。在SSD论文的最后,也讨论了空洞卷积对结果有好的影响。最后,将这些卷积层和池化层放入list中,并返回这个list。

    def vgg(cfg=base['300'], batch_norm=False):
        '''
        该函数来源于torchvision.models.vgg19()中的make_layers()
        '''
        layers = []
        in_channels = 3
    
        # vgg主体部分,到论文的conv6之前
        for v in cfg:
            if v == 'M':
                # ceil_mode是向上取整
                layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)]
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
                if batch_norm:
                    layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
                else:
                    layers += [conv2d, nn.ReLU(inplace=True)]
                in_channels = v
        pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True)
    
        conv6 = nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6)
        conv7 = nn.Conv2d(1024, 1024, kernel_size=1)
        layers += [pool5, conv6, nn.ReLU(inplace=True), conv7, nn.ReLU(inplace=True)]
    
        return layers

    接下来,我们来了解一下SSD在vgg中新增层,即conv7之后的网络层。同样,函数add_extras()需要传入字典extras,来构建网络层。这里可以留意一下,kernel_size的写法,(1,3)为一个元祖tuple,flag来控制取哪个值,即可变换使用3*3或者1*1的卷积核,减少代码的冗余。最后将新构建的层存入list中,并返回这个list。

    def add_extras(cfg=extras['300'], in_channels=1024):
        '''
        完成SSD后半部分的网络构建,即作者新加上去的网络,从conv7之后到conv11_2
        '''
        layers = []
        flag = False  # 交替控制卷积核,使用1*1或者使用3*3
        for k, v in enumerate(cfg):
            if in_channels != 'S':
                if v == 'S':
                    layers += [nn.Conv2d(in_channels, cfg[k + 1], kernel_size=(1, 3)[flag], stride=2, padding=1)]
                else:
                    layers += [nn.Conv2d(in_channels, v, kernel_size=(1, 3)[flag])]
                flag = not flag
            in_channels = v
        return layers

    当有了vgg的主干网络和新增层后,可以将某些层进行特征融合和预测了。这里,就需要使用到函数multibox()。需要将vgg主干网络和新增层的list、字典mbox和类别数传入函数中。首先,函数multibox()会创建两个list,用于保存位置回归的层和置信度的层。对于每个用于融合的特征层,会分成两部分,一个用于回归,使用3*3的卷积,输出通道数是cfg[k] * 4,其中cfg[k]表示每个位置上锚点框的数量,4表示[x_min,y_min,x_max,y_max];另外一个用于类别的判断,也是使用3*3的卷积,输出通道数是cfg[k] * num_class,表示每个锚点框判断其属于哪一个类别,在voc中,num_class=21(包含背景)。可以理解成将此特征层分成了分类头和回归头,每个锚点框会输出4个坐标和21个类别置信度。最后将vgg主干网络、新增层、分类头和回归头返回。

    def multibox(vgg, extra_layers, cfg, num_class):
        '''
        返回vgg网络,新增网络,位置网络和置信度网络
        '''
        loc_layers = []  # 判断位置
        conf_layers = []  # 判断置信度
    
        vgg_source = [21, -2]  # 21表示conv4_3的序号,-2表示conv7的序号
        for k, v in enumerate(vgg_source):
            # vgg[v]表示需要提取的特征图
            # cfg[k]代表该特征图下每个点对应的锚点框数量
            loc_layers += [nn.Conv2d(vgg[v].out_channels, cfg[k] * 4, kernel_size=3, padding=1)]
            conf_layers += [nn.Conv2d(vgg[v].out_channels, cfg[k] * num_class, kernel_size=3, padding=1)]
    
        for k, v in enumerate(extra_layers[1::2], 2):
            # [1::2]表示,从第1位开始,步长为2
            # 这么做的目的是,新增加的层,都包含2层卷积,需要提取后面那层卷积的结果作为特征图
            loc_layers += [nn.Conv2d(v.out_channels, cfg[k] * 4, kernel_size=3, padding=1)]
            conf_layers += [nn.Conv2d(v.out_channels, cfg[k] * num_class, kernel_size=3, padding=1)]
    
        return vgg, extra_layers, (loc_layers, conf_layers)

    函数multibox()返回的各个层,用于初始化类SSD。首先,由于因为“train”阶段和“test”阶段是有点区别的,本节依然主要将“train”阶段,因此,需要传入phase参数,参数只能是两个值(train,test)。函数PriorBox()的作用是来创建先验锚点框,返回的shape为[8732,4],其中具有8732个锚点框,4表示每个锚点框的坐标[中心点x,中心点y,宽,高],这里的坐标值有点不太一样。由于传入的网络层是以list列表的形式,因此,用nn.ModuleList()将其转换为pytorch的网络结构。

    接下来看类SSD中的函数forward(),用于前向推理。按顺序对输入图片进行处理,在conv4中,需要对特征图进行L2正则化。并将用于特征融合的特征图存在放sources中。在得到5个用于融合的特征图后,将这些特征图输入到分类头和回归头中,每个特征图对应各自的分类头和回归头。这里注意一下,分类头或者回归头卷积后,使用了permute()函数。该函数的作用是交换维度,原本的维度是[batch_size,channel,height,weight],交换维度后变成了[batch_size,height,weight,channel],这样做的目的是方便后续的处理。将处理后的结果保存在loc和conf这两个List中。后续接着对loc和conf进行变换,利用view()函数,最终,loc的shape为[batch_size,8732*4],conf的shape为[batch_size,8732*21]。

    最后,将loc和conf这两个List又变换维度,返回出去,用于计算loss损失函数(感觉这么多变换,有点重复呀,应该可以省略一部分)。"train"阶段和"test"阶段返回的结果类似,其中不同点是,在test阶段,置信度需要经过softmax。

    class SSD(nn.Module):
        '''
        构建SSD的主函数,将base(vgg)、新增网络和位置网络与置信度网络组合起来
        '''
    
        def __init__(self, phase, size, base, extras, head, num_classes):
            super(SSD, self).__init__()
            self.phase = phase
            self.num_classes = num_classes
            self.priors = torch.Tensor(PriorBox(voc))
            self.size = size
    
            # SSD网络
            self.vgg = nn.ModuleList(base)
            # 对conv4_3的特征图进行L2正则化
            self.L2Norm = L2Norm(512, 20)
    
            self.extras = nn.ModuleList(extras)
            self.loc = nn.ModuleList(head[0])
            self.conf = nn.ModuleList(head[1])
    
            if phase == 'test':
                self.softmax = nn.Softmax(dim=-1)
                self.detect = Detect(num_classes=self.num_classes, top_k=200,
                                     conf_thresh=0.01, nms_thresh=0.45)
    
        def forward(self, x):
            sources = []  # 保存特征图
            loc = []  # 保存每个特征图进行位置网络后的信息
            conf = []  # 保存每个特征图进行置信度网络后的信息
    
            # 处理输入至conv4_3
            for k in range(23):
                x = self.vgg[k](x)
    
            # 对conv4_3进行L2正则化
            s = self.L2Norm(x)
            sources.append(s)
    
            # 完成vgg后续的处理
            for k in range(23, len(self.vgg)):
                x = self.vgg[k](x)
            sources.append(x)
    
            # 使用新增网络进行处理
            for k, v in enumerate(self.extras):
                x = F.relu(v(x), inplace=True)
                if k % 2 == 1:
                    sources.append(x)
    
            # 将特征图送入位置网络和置信度网络
            # l(x)或者c(x)的shape为[batch_size,channel,height,weight],使用了permute后,变成[batch_size,height,weight,channel]
            # 这样做应该是为了方便后续处理
            for (x, l, c) in zip(sources, self.loc, self.conf):
                loc.append(l(x).permute(0, 2, 3, 1).contiguous())
                conf.append(c(x).permute(0, 2, 3, 1).contiguous())
    
            # 进行格式变换
            loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)  # [batch_size,34928],锚点框的数量8732*4=34928
            conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)
    
            if self.phase == 'train':
                output = (loc.view(loc.size(0), -1, 4),  # [batch_size,num_priors,4]
                          conf.view(conf.size(0), -1, self.num_classes),  # [batch_size,num_priors,21]
                          self.priors)  # [num_priors,4]
            else:  # Test
                output = self.detect(
                    loc.view(loc.size(0), -1, 4),  # 位置预测
                    self.softmax(conf.view((conf.size(0), -1, self.num_classes))),  # 置信度预测
                    self.priors.cuda()  # 先验锚点框
                )
    
            return output

    在上面类SSD中,提及到了先验锚点框的构建函数PriorBox(),这个函数在models/prior_box.py中。首先,根据用于融合的特征图尺寸和product()函数,生成一系列的点,如(0,0),(0,1),(0,2)等。然后根据这些像素点位置,偏移0.5作为锚点框的中心点,即cx和cy,并将其归一化。然后计算论文中的$s_k$和${s_k}'$,对应s_k和s_k_prime,先计算$a_r=1$的情况,再计算其余$a_r$的情况。此时,mean的shape为[1,34928],因此,需要使用view()函数,将其切割出来,变成[8732,4]。记得,这里的锚点框的坐标是[中心点x,中心点y,宽,高]。

    def PriorBox(cfg):
        '''
        为所有特征图生成预设的锚点框,返回所有生成的锚点框,尺寸为[8732,4],
        每行表示[中心点x,中心点y,宽,高]
        '''
        image_size = cfg['min_dim']  # 300
        feature_maps = cfg['feature_maps']  # [38, 19, 10, 5, 3, 1],特征图尺寸
        steps = cfg['steps']  # [8, 16, 32, 64, 100, 300]
        min_sizes = cfg['min_sizes']  # [30, 60, 111, 162, 213, 264]
        max_sizes = cfg['max_sizes']  # [60, 111, 162, 213, 264, 315]
        aspect_ratios = cfg['aspect_ratios']  # [[2], [2, 3], [2, 3], [2, 3], [2], [2]]
    
        mean = []
        # 为所有特征图生成锚点框
        for k, f in enumerate(feature_maps):
            # product(list1,list2)的作用是依次取出list1中的每1个元素,与list2中的每1个元素,
            # 组成元组,然后,将所有的元组组成一个列表,返回
            # 而这里使用了repeat,说明1个list重复2次
            for i, j in product(range(f), repeat=2):
                f_k = image_size / steps[k]
                # 计算中心点,这里的j是沿x方向变化的
                cx = (j + 0.5) / f_k
                cy = (i + 0.5) / f_k
    
                # aspect_ratio=1有两种情况,s_k=s_k,s_k=sqrt(s_k*s_(k+1))
                s_k = min_sizes[k] / image_size
                mean += [cx, cy, s_k, s_k]
    
                s_k_prime = sqrt(s_k * (max_sizes[k] / image_size))
                mean += [cx, cy, s_k_prime, s_k_prime]
    
                # 剩余的aspect_ratio
                for ar in aspect_ratios[k]:
                    mean += [cx, cy, s_k * sqrt(ar), s_k / sqrt(ar)]
                    mean += [cx, cy, s_k / sqrt(ar), s_k * sqrt(ar)]
    
        # 此时的mean是1*34928的list,要4个数就分割出来,所以需要用view,从而变成[8732,4],即有8732个锚点框
        output = torch.Tensor(mean).view(-1, 4)
        if cfg['clip']:
            # 对每个元素进行截断限制,限制为[0,1]之间
            output.clamp_(min=0, max=1)
        return output

    最后,类SSD中还对conv4的特征层使用了L2正则化,该函数在models/l2norm.py中。在函数forwand()中,按每个通道对其值进行L2正则化,即除以通道的平方根来实现归一化。

    class L2Norm(nn.Module):
        '''
        对conv4_3进行l2归一化
        '''
    
        def __init__(self, n_channels, scale):
            super(L2Norm, self).__init__()
            self.n_channels = n_channels
            self.gamma = scale
            self.eps = 1e-10
            self.weight = nn.Parameter(torch.Tensor(self.n_channels))  # n_channels个随机数
            self.reset_parameters()
    
        def reset_parameters(self):
            # 使用gamma来填充weight的每个值
            nn.init.constant_(self.weight, self.gamma)
    
        def forward(self, x):
            # 按通道进行求值
            norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps  # [1,1,38,38]
            x = torch.div(x, norm)
            # 将weight通过3个unsqueeze展开成[1,512,1,1],然后通过expand_as进行扩展,形成[1,512,38,38]
            out = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as(x) * x
            return out

    至此,SSD的网络搭建过程已经完成了,通过类SSD的forward()函数,即能返回预测框的坐标和类别置信度,以此可以计算损失函数。 

     

  • 相关阅读:
    深入理解javascript的this关键字
    很简单的JQuery网页换肤
    有关垂直居中
    层的半透明实现方案
    常用meta整理
    web前端页面性能优化小结
    关于rem布局以及sprit雪碧图的移动端自适应
    mysql入过的坑
    日期格式化函数
    基于iframe父子页面传值的方法。
  • 原文地址:https://www.cnblogs.com/dengshunge/p/11937828.html
Copyright © 2011-2022 走看看