zoukankan      html  css  js  c++  java
  • SSD源码解读——网络测试

    之前,对SSD的论文进行了解读,可以回顾之前的博客:https://www.cnblogs.com/dengshunge/p/11665929.html

    为了加深对SSD的理解,因此对SSD的源码进行了复现,主要参考的github项目是ssd.pytorch。同时,我自己对该项目增加了大量注释:https://github.com/Dengshunge/mySSD_pytorch

    搭建SSD的项目,可以分成以下四个部分:

    1. 数据读取
    2. 网络搭建
    3. 损失函数的构建
    4. 网络测试

    接下来,本篇博客重点分析网络测试


    在eval.py文件中,首先需要搭建测试用的网络。此时,需要将传入的第一个参数换成"test"字符串,这是因为训练和测试阶段,网络的输出会有不同。在测试阶段,会对预测框进行nms等操作。然后是常规的加载训练模型,将网络设置成eval模式,不更新梯度。

        num_classes = len(labelmap) + 1  # +1 for background
        net = build_ssd('test', 300, num_classes)  # initialize SSD
        net.load_state_dict(torch.load(args.trained_model))
        net.eval()

    我们再来看看,测试阶段中SSD网络的不同。在ssd.py中,如果是test阶段,在类ssd()中,会初始化函数Detect()函数。并且在类SSD()的forward函数中,将坐标预测结果,经过softmax的置信度预测结果和先验锚点框传递进去,进行运算。最终输出一个tensor,shape为[batch,num_classes,top_k,5]。其中,num_classes是类别总数,对于VOC而言,为21;top_k表示最多取top_k个锚点框进行输出,论文中值为200;5表示[confidence,xmin,ymin,xmax,ymax]。

            if phase == 'test':
                self.softmax = nn.Softmax(dim=-1)
                self.detect = Detect(num_classes=self.num_classes, top_k=200,
                                     conf_thresh=0.01, nms_thresh=0.45)
            if self.phase == 'train':
                output = (loc.view(loc.size(0), -1, 4),  # [batch_size,num_priors,4]
                          conf.view(conf.size(0), -1, self.num_classes),  # [batch_size,num_priors,21]
                          self.priors)  # [num_priors,4]
            else:  # Test
                output = self.detect(
                    loc.view(loc.size(0), -1, 4),  # 位置预测
                    self.softmax(conf.view((conf.size(0), -1, self.num_classes))),  # 置信度预测
                    self.priors.cuda()  # 先验锚点框
                )

    在models/detection.py中,定义了类Detect()。首先,创建output来保存最终结果,其shape为[batch,num_classes,top_k,5],其具体含义可以看上面。然后对对置信度结果进行transpose转置,这样做的目的是方便后续计算,conf_preds的shape为[batch,num_classes,num_priors]。接着,由于网络预测出来的位置预测结果,并不是真正的坐标,需要对其结果进行解码,得到真正的坐标(其范围是[0,1]之间);然后,对每个类别进行单独计算(不包含背景), c_mask = conf_scores[cl].gt(self.conf_thresh) 表示对某一类(如bike)得到shape为[1,8732]的tensor,每个值表示预测框对该类别的置信度,通过函数gt(),得到大于置信度阈值的掩码(即为c_mask,元素组成是true或者false)。通过这个mask,就可以获得大于置信度要求的预测框(包含置信度和坐标),并通过nms操作,得到最终输出锚点框的Index,将对应的结果(置信度和坐标)保存在output中。

    class Detect(Function):
        def __init__(self, num_classes, top_k, conf_thresh, nms_thresh):
            self.num_classes = num_classes
            self.top_k = top_k
            # Parameters used in nms.
            self.nms_thresh = nms_thresh  # 非极大值抑制阈值
            self.conf_thresh = conf_thresh  # 置信度阈值
    
        def forward(self, loc_data, conf_data, prior_data):
            '''
            :param loc_data: 模型预测的锚点框位置偏差信息,shape[batch,num_priors,4]
            :param conf_data: 模型预测的锚点框置信度,[batch,num_priors,num_classes]
            :param prior_data: 先验锚点框,[num_priors,4]
            :return:最终预测结果,shape[batch,num_classes,top_k,5],其中5表示[置信度,xmin,ymin,xmax,ymax],
                top_k中前面不为0的是预测结果,后面为0是为了填充
            '''
            num = loc_data.shape[0]  # batch size
            num_priors = prior_data.shape[0]  # 8732
            output = torch.zeros(num, self.num_classes, self.top_k, 5)  # 保存结果
            conf_preds = conf_data.view(num, num_priors, self.num_classes).transpose(2, 1)  # 置信度预测,transpose是为了后续操作方便
    
            for i in range(num):
                decoded_boxes = decode(loc_data[i], prior_data, voc['variance'])  # shape:[num_priors,4],对预测锚点框进行解码
                # 对每个类别,执行nms
                conf_scores = conf_preds[i].clone()  # shape:[num_classes,num_priors]
                for cl in range(1, self.num_classes):
                    c_mask = conf_scores[cl].gt(self.conf_thresh)  # 和置信度阈值进行比较,大于为true,否则为false
                    scores = conf_scores[cl][c_mask]  # 得到置信度大于阈值的那些锚点框置信度
                    if scores.shape[0] == 0:
                        # 说明锚点框与这一类的GT框不匹配,简介说明,不存在这一类的目标
                        continue
                    l_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes)
                    boxes = decoded_boxes[l_mask].view(-1, 4)  # 得到置信度大于阈值的那些锚点框
                    ids = nms(boxes, scores, self.nms_thresh, self.top_k)  # 对置信度大于阈值的那些锚点框进行nms,得到最终预测结果的index
                    output[i, cl, :len(ids)] = torch.cat((scores[ids].unsqueeze(1),
                                                          boxes[ids]), 1)  # [置信度,xmin,ymin,xmax,ymax]
            return output

    解码函数decode()在models.box_utils.py中。在训练阶段,我们对坐标进行了带方差的编码,因此,需要对坐标进行同样方式的解码。

    $$b^{cx}=d^w(var[0]*l^{cx})+d^{cx},  b^{cy}=d^h(var[1]*l^{cy})+d^{cy}$$

    $$b^w=d^wexp(var[2]*l^w),  b^h=d^hexp(var[3]*l^h)$$

    def decode(loc, priors, variances):
        '''
        对编码的坐标进行解码,返回预测框的坐标
        :param loc: 网络预测的锚点框偏差信息,shape[num_priors,4]
        :param priors: 先验锚点框,[num_priors,4]
        :return: 预测框的坐标[num_priors,4],4代表[xmin,ymin,xmax,ymax]
        '''
        boxes = torch.cat((
            priors[:, :2] + loc[:, :2] * priors[:, 2:] * variances[0],
            priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)  # [中心点x,中心点y,宽,高]
        boxes[:, :2] -= boxes[:, 2:] / 2  # xmin,ymin
        boxes[:, 2:] += boxes[:, :2]  # xmax,ymax
        return boxes

    在models.box_utils.py中,还存在着nms()函数。首先对置信度进行降序排序,取出置信度最大的前top_k个用于判断,其余的锚点框,则不加入判断中。然后,对这些预测框进行nms操作,即判断一个锚点框与其余锚点框的IOU,只保留IOU小于阈值的锚点框,排除大于阈值的锚点框,将剩余的锚点框再次循环,直至idx中不存在元素,即keep中保留的锚点框编号为最终输出的锚点框。

    def nms(boxes, scores, overlap=0.5, top_k=200):
        '''
        进行nms操作
        :param boxes: 模型预测的锚点框的坐标
        :param scores: 模型预测的锚点框对应某一类的置信度
        :param overlap:nms阈值
        :param top_k:选取前top_k个预测框进行操作
        :return:预测框的index
        '''
        keep = torch.zeros(scores.shape[0])
        if boxes.numel() == 0:  # numel()返回tensor里面所有元素的个数
            return keep
        _, idx = scores.sort(0)  # 升序排序
        idx = idx[-top_k:]  # 取得最大的top_k个置信度对应的index
    
        keep = []  # 记录最大最终锚点框的index
        while idx.numel() > 0:
            i = idx[-1]  # 取出置信度最大的锚点框的index
            keep.append(i)
            idx = idx[:-1]
            if idx.numel() == 0:
                break
            IOU = jaccard(boxes[i].unsqueeze(0), boxes[idx])  # 计算这个锚点框与其余锚点框的iou
            mask = IOU.le(overlap).squeeze(0)
            idx = idx[mask]  # 排除大于阈值的锚点框
        return torch.tensor(keep)

    接下来,还是回到eval.py函数中,继续理解网络测试代码。接下来加载测试数据,方式与之前介绍的类似。但这里的图片处理方式函数BaseTransform()并不需要进行数据增强,值需要将数据进行resize和减去均值。其余的内容,几乎一致。

        # 记载数据
        dataset = VOCDetection(args.voc_root, [('2007', 'test')], BaseTransform(300, (104, 117, 123)),
                               VOCAnnotationTransform())

    接下来,循环每张测试图片。根据每类类别,得到第i张图片的第j个类别的信息,包含检测框的置信度和坐标,其中,坐标是真实坐标(不是[0,1]之间),并将其放入变量all_boxes中。变量all_boxes是一个类似二维矩阵的变量,  all_boxes = [[[] for _ in range(num_images)] for _ in range(len(labelmap) + 1)] ,其中,列表示每个类别,行表示每张图片的检测信息。这样,就能得到所有图片所有类别的检测信息了,就可以用于下面的准确率、召回率和mAP计算了。

        for i in range(num_images):
            img, gt, h, w = dataset.pull_item(i)
            img = img.unsqueeze(0)
            if torch.cuda.is_available():
                img = img.cuda()
    
            detections = net(img)  # 得到结果,shape[1,21,200,5]
            for j in range(1, detections.shape[1]):  # 循环计算每个类别
                dets = detections[0, j, :]  # shape[200,5],表示每个类别中最多200个锚点框,每个锚点框有5个值[conf,xmin,ymin,xmax,ymax]
                mask = dets[:, 0].gt(0.).expand(5, dets.shape[0]).t()  # 取出置信度大于0的情况.因为可能会出现实际有值的锚点框少于200个
                dets = torch.masked_select(dets, mask).view(-1, 5)  # 取出这些锚点框
                if dets.shape[0] == 0:
                    # 说明该图片不存在该类别
                    continue
                boxes = dets[:, 1:]  # 取出锚点框坐标
                # 计算出真实坐标
                boxes[:, 0] *= w
                boxes[:, 1] *= h
                boxes[:, 2] *= w
                boxes[:, 3] *= h
    
                scores = dets[:, 0].numpy()
                # np.newaxis增加一个新轴
                # 注意[xmin,ymin,xmax,ymax,conf]
                cls_dets = np.hstack((boxes.numpy(), scores[:, np.newaxis])).astype(np.float32, copy=False)
    
                all_boxes[j][i] = cls_dets

    利用上面得到的all_boxes信息,进入到测试函数的关键部分,函数evaluate_detections()。首先,将所有检测结果已文本的形式保存下来,方便读取和调用。然后再执行计算评价指标的函数。

    def evaluate_detections(all_boxes, dataset):
        write_voc_results_file(all_boxes, dataset)  # 将所有检测结果写成文本,保存下来
        do_python_eval(use_07=False)

    在函数write_voc_results_file()中,实现的功能就是根据某一类别和测试图片的index,读取变量all_boxes中的检测信息,将其按照[图片名,置信度,xmin,ymin,xmax,ymax]的形式,写入文本中。如VOC,我们会得到20个文本文件,不同文本表示不同的类别;同一文本下,包含了所有测试图片对该类别的检测结果。

    def write_voc_results_file(all_boxes, dataset):
        # 将检测结果按照每类写成文本,方便后面读取结果
        for cls_ind, cls in enumerate(labelmap):
            print('Writing {:s} VOC results file'.format(cls))
            if not os.path.exists(args.save_det_result):
                os.mkdir(args.save_det_result)
            filename = os.path.join(args.save_det_result, 'det_%s.txt' % (cls))
            with open(filename, 'w') as f:
                for im_ind, index in enumerate(dataset.ids):  # dataset.ids:[path,图片名]
                    dets = all_boxes[cls_ind + 1][im_ind]  # 测试的时候,图片是按这个顺序读取的
                    if dets == []:
                        continue
                    for k in range(dets.shape[0]):
                        f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}
    '.
                                format(index[1], dets[k, -1],
                                       dets[k, 0] + 1, dets[k, 1] + 1,
                                       dets[k, 2] + 1, dets[k, 3] + 1))

    当将信息保存完后,就进入到do_python_eval()函数中。参数use_07表示(true)使用2007的11点计算mAP方式还是(false)2010年的mAP计算方式。首先按照每个类别,读取上述保存检测结果的文件。进入到关键函数voc_eval()中,得到召回率、准确率和AP值。

    def do_python_eval(use_07=True):
        aps = []  # 保存所有类别的AP
        for i, cls in enumerate(labelmap):
            filename = os.path.join(args.save_det_result, 'det_%s.txt' % (cls))  # 读取这一类别的检测结果,对应上刚刚保存的结果
            rec, prec, ap = voc_eval(filename,
                                     os.path.join(args.voc_root, 'VOC2007', 'ImageSets', 'Main', 'test.txt'),
                                     cls,
                                     args.cachedir,
                                     ovthresh=0.5,
                                     use_07_metric=use_07)
            aps += [ap]
            print('AP for {} = {:.4f}'.format(cls, ap))

    在函数voc_eval()中,首先,会读取所有测试图片相关的xml文件,读取的方式与一开始介绍的数据读取类似,将所有信息保存在字典recs中,其中,key为图片名字,value是xml文件信息。并将该resc信息保存下来,方便以后继续读取。因为这里面都是真实信息,变动相对较少。然后根据某一类别,在字典recs中读取每个图片,将该类别的信息提取出来,构成字典class_recs,其中key为某一类下的图片名称,value为GT框坐标、是否难例和是否已经检测过。上面是处理真实信息,接下来,处理预测信息。读取某一类的预测结果文件,该文件在函数write_voc_results_file()中形成的。然后对文件内容进行分割,得到文件名,置信度,预测框等集合,并根据置信度,对3个集合进行降序排列。按顺序读取每个预测框,计算该预测框与这张图所有GT框的IOU。当IOU大于阈值且该GT框没有匹配过时,tp的相应位置置1,否则fp的相应位置置0。由此可以,tp和fp互斥。上述的tp和fp并不是true positive和false positive,需要进行行累加,并除以预测框总数或者GT框总数,才能得到召回率和准确率。之后通过计算,得到给类别的AP值。

    def voc_eval(detpath,  # 某一类别下检测结果,每一行由文件名,置信度和检测坐标组成
                 imagesetfile,  # 包含所有测试图片的文件
                 classname,  # 需要检测的类别
                 cachedir,  # 缓存GT框的pickle文件
                 ovthresh=0.5,  # IOU阈值
                 use_07_metric=True):
        '''
        假设检测结果在detpath.format(classname)下
        假设GT框坐标在annopath.format(imagename)
        假设imagesetfile每行仅包含一个文件名
        缓存所有GT框
        '''
        if not os.path.isdir(cachedir):
            os.mkdir(cachedir)
        cachefile = os.path.join(cachedir, 'annots.pkl')
        # 读取所有检测图片
        with open(imagesetfile, 'r') as f:
            lines = f.readlines()
        imagenames = [x.strip() for x in lines]  # 每张测试图片的名字
    
        # 下面代码是创建缓存文件,方便读取
        if not os.path.isfile(cachefile):
            # 不存在GT框缓存文件,则创建
            recs = {}  # key为图片名字,value为该图片下所有检测信息
            for i, imagename in enumerate(imagenames):
                recs[imagename] = parse_rec(
                    os.path.join(args.voc_root, 'VOC2007', 'Annotations',
                                 '%s.xml') % (imagename))  # 返回该图片下所有xml信息,包含所有目标
                if i % 100 == 0:
                    print('Reading annotation for {:d}/{:d}'.format(
                        i + 1, len(imagenames)))
            # 保存下来,方便下次读取
            print('Saving cached annotations to {:s}'.format(cachefile))
            with open(cachefile, 'wb') as f:
                pickle.dump(recs, f)
        else:
            # 如果已经存在该文件,则加载回来即可
            with open(cachefile, 'rb') as f:
                recs = pickle.load(f)
    
        # 为这一类提取GT框
        class_recs = {}
        npos = 0  # 这一类别的gt框总数
        for imagename in imagenames:
            R = [obj for obj in recs[imagename] if obj['name'] == classname]  # 提取某张测试图片下该类别的信息
            bbox = np.array([x['bbox'] for x in R])  # GT框坐标
            difficult = np.array([x['difficult'] for x in R]).astype(np.bool)  # 元素为true或者false,true表示难例
            det = [False] * len(R)  # 长度为len(R)的list,用于表示该GT框是否已经匹配过,len(R)可以理解为该测试图片下,该类别的数量
            npos = npos + sum(~difficult)  # 只选取非难例,计算非难例的个数,可以理解为GT框的个数
            class_recs[imagename] = {'bbox': bbox,
                                     'difficult': difficult,
                                     'det': det}
    
        # 读取这一类的检测结果
        with open(detpath, 'r') as f:
            lines = f.readlines()
    
        if any(lines) == 1:  # 不为空
            splitlines = [x.strip().split(' ') for x in lines]
            image_ids = [x[0] for x in splitlines]  # 图片名称集合,包含重复的
            confidence = np.array([float(x[1]) for x in splitlines])  # 置信度集合
            BB = np.array([[float(z) for z in x[2:]] for x in splitlines])  # 检测框集合
    
            # 根据置信度,降序排列
            sorted_ind = np.argsort(-confidence)  # 降序排名
            sorted_scores = np.sort(-confidence)  # 降序排列
            BB = BB[sorted_ind, :]  # 检测框根据置信度进行降序排列
            image_ids = [image_ids[x] for x in sorted_ind]
    
            nd = len(image_ids)  # 检测的目标总数
            tp = np.zeros(nd)  # 记录tp
            fp = np.zeros(nd)  # 记录fp,与tp互斥
    
            for d in range(nd):  # 循环每个预测框
                R = class_recs[image_ids[d]]  # 该图片下的真实信息
                bb = BB[d, :].astype(float)  # 预测框的坐标
                ovmax = -np.inf  # 预测框与GT框的IOU
                BBGT = R['bbox'].astype(float)  # GT框的坐标
    
                if BBGT.size > 0:
                    # 计算多个GT框与一个预测框的IOU,选择最大IOU
                    # 下面是计算IOU的流程
                    ixmin = np.maximum(BBGT[:, 0], bb[0])
                    iymin = np.maximum(BBGT[:, 1], bb[1])
                    ixmax = np.minimum(BBGT[:, 2], bb[2])
                    iymax = np.minimum(BBGT[:, 3], bb[3])
                    iw = np.maximum(ixmax - ixmin, 0.)
                    ih = np.maximum(iymax - iymin, 0.)
                    inters = iw * ih
                    uni = ((bb[2] - bb[0]) * (bb[3] - bb[1]) +
                           (BBGT[:, 2] - BBGT[:, 0]) *
                           (BBGT[:, 3] - BBGT[:, 1]) - inters)
                    overlaps = inters / uni
                    ovmax = np.max(overlaps)  # 得到该预测框与GT框最大的IOU值
                    jmax = np.argmax(overlaps)  # 得到该预测框对应最大IOU的GT框的index
    
                if ovmax > ovthresh:
                    # 当IOU大于阈值,才有机会判断为正例
                    # 判断为fp有两种情况:
                    # 1.该GT框被置信度高的预测框匹配过
                    # 2.IOU小于阈值
                    if not R['difficult'][jmax]:
                        # 该GT框要求之前没有匹配过
                        # 由于置信度是降序排序的,GT框只匹配置信度最高的,其余认为是FP
                        tp[d] = 1.
                        R['det'][jmax] = 1
                    else:
                        fp[d] = 1.
                else:
                    fp[d] = 1
    
            # 计算recall,precision
            fp = np.cumsum(fp)  # shape:[1,nd]
            tp = np.cumsum(tp)  # shape:[1,nd]
            rec = tp / float(npos)  # 召回率
            prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)  # 准确率,防止除0
            ap = voc_ap(rec, prec, use_07_metric)
        else:
            rec = -1.
            prec = -1.
            ap = -1.
    
        return rec, prec, ap

    根据召回率,准确率,就可以计算该类别的AP值了。计算AP值有两种方法,第一种是2007年的11点计算,该方法给召回率设定11个阈值,如 np.arange(0., 1.1, 0.1) ,计算大于阈值情况下的最大准确率,获得11个准确率后,求平均,就得到了AP值;第二种方法是2010年提出的,首先将PR曲线进行平滑,第i-i个点去第i-1个和第i个点的最大值,将PR曲边变成了递减曲线,然后计算该递减曲线下的面积,得到AP值。

    def voc_ap(rec, prec, use_07_metric=True):
        '''
        根据召回率和准确率,计算AP
        AP计算有两种方式:1.11点计算;2.最大面积计算
        :param rec: [1,num_all_detect]
        :param prec: [1,num_all_detect]
        '''
        if use_07_metric:
            # 旧版,11点计算
            ap = 0.
            for t in np.arange(0., 1.1, 0.1):
                # 给召回率设定阈值,统计当召回率大于阈值的情况下,最大的准确率
                if np.sum(rec >= t) == 0:
                    # 说明召回率没有比t更大
                    p = 0
                else:
                    p = np.max(prec[rec >= t])
                ap = ap + p / 11
        else:
            # 增加两个数字,是为了方便计算
            mrec = np.concatenate(([0.], rec, [1.]))  # shape:[1,num_all_detect+2]
            mpre = np.concatenate(([0.], prec, [0.]))
    
            # 计算最大面积
            for i in range(mpre.size - 1, 0, -1):
                # 取右边的最大值
                mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
    
            # 得到与前面的数值不一样的index,可以理解成,计算面积时的边长
            i = np.where(mrec[1:] != mrec[:-1])[0]
            ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])  # 计算面积,长*宽
        return ap

    至此,SSD的网络检测代码已经解读完成。

  • 相关阅读:
    【Java EE 学习 81】【CXF框架】【CXF整合Spring】
    【Java EE 学习 80 下】【调用WebService服务的四种方式】【WebService中的注解】
    【Java EE 学习 80 上】【WebService】
    【Java EE 学习 79 下】【动态SQL】【mybatis和spring的整合】
    【Java EE 学习 79 上】【mybatis 基本使用方法】
    【Java EE 学习 78 下】【数据采集系统第十天】【数据采集系统完成】
    【Java EE 学习 78 中】【数据采集系统第十天】【Spring远程调用】
    【Java EE 学习 78 上】【数据采集系统第十天】【Service使用Spring缓存模块】
    【Java EE 学习 77 下】【数据采集系统第九天】【使用spring实现答案水平分库】【未解决问题:分库查询问题】
    【Java EE 学习 77 上】【数据采集系统第九天】【通过AOP实现日志管理】【通过Spring石英调度动态生成日志表】【日志分表和查询】
  • 原文地址:https://www.cnblogs.com/dengshunge/p/11991545.html
Copyright © 2011-2022 走看看