zoukankan      html  css  js  c++  java
  • 统一 map 测试标准

    使用不同的深度学习框架,以及训练不同的目标检测模型时,其各源码map测试接口的实现都不一样。为了方便统一对比不同模型的优劣,使用如下方法进行各个模型map的计算

    1、制作相同的测试集,例如 valid.txt 文件,里面保存需要测试图片的路径

    2、检测识别 valid.txt 文件里的图片数据,将测试结果保存为txt文件

    比如将测试结果保存到results文件夹下,每个目标类别的识别结果保存为一个txt文件,文件名为目标类别名。例如,文件夹结构如下
    results
    --person.txt
    --dog.txt
    --car.txt
    每个txt文件里的内容保存格式为:图片名称,类别得分,检测框的坐标(xyxy);例如 person.txt 文件如下,其中每个bbox保存一行
      train4656.jpg 0.149583 1047 550 1081 638
      train4656.jpg 0.133522 1270 502 1289 550
      train4656.jpg 0.124698 1231 494 1256 553
      train4183.jpg 0.016136 1898 380 1920 464
      train1199.jpg 0.004282 1462 478 1534 677
      train1199.jpg 0.003982 655 457 723 611

    3、使用如下代码,计算map

    calc_map.py

    from voc_eval import voc_eval
    import os
    
    current_path = os.getcwd()
    results_path = current_path+"/results"
    sub_files = os.listdir(results_path)
    
    mAP = []
    for i in range(len(sub_files)):
        class_name = sub_files[i].split(".txt")[0]
        rec, prec, ap = voc_eval('./results/{}.txt', './Annotations/{}.xml', './data/valid.txt', class_name, '.')
        print("AP for {} = {} ".format(class_name, ap))
        mAP.append(ap)
    mAP = tuple(mAP)
    
    print("***************************")
    
    count = 0
    for ap in mAP:
        if ap > 0:
            count += ap
    
    print("Mean AP = {}".format( float(count)/len(mAP)) )

    voc_eval.py

    # --------------------------------------------------------
    # Fast/er R-CNN
    # Licensed under The MIT License [see LICENSE for details]
    # Written by Bharath Hariharan
    # --------------------------------------------------------
    
    import xml.etree.ElementTree as ET
    import os
    import pickle
    import numpy as np
    
    def parse_rec(filename):
        """ Parse a PASCAL VOC xml file """
        tree = ET.parse(filename)
        objects = []
        for obj in tree.findall('object'):
            obj_struct = {}
            obj_struct['name'] = obj.find('name').text
            # obj_struct['pose'] = obj.find('pose').text
            # obj_struct['truncated'] = int(obj.find('truncated').text)
            obj_struct['difficult'] = 0 #int(obj.find('difficult').text)
            bbox = obj.find('bndbox')
            obj_struct['bbox'] = [int(bbox.find('xmin').text),
                                  int(bbox.find('ymin').text),
                                  int(bbox.find('xmax').text),
                                  int(bbox.find('ymax').text)]
            objects.append(obj_struct)
    
        return objects
    
    def voc_ap(rec, prec, use_07_metric=False):
        """ ap = voc_ap(rec, prec, [use_07_metric])
        Compute VOC AP given precision and recall.
        If use_07_metric is true, uses the
        VOC 07 11 point method (default:False).
        """
        if use_07_metric:
            # 11 point metric
            ap = 0.
            for t in np.arange(0., 1.1, 0.1):
                if np.sum(rec >= t) == 0:
                    p = 0
                else:
                    p = np.max(prec[rec >= t])
                ap = ap + p / 11.
        else:
            # correct AP calculation
            # first append sentinel values at the end
            mrec = np.concatenate(([0.], rec, [1.]))
            mpre = np.concatenate(([0.], prec, [0.]))
    
            # compute the precision envelope
            for i in range(mpre.size - 1, 0, -1):
                mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
    
            # to calculate area under PR curve, look for points
            # where X axis (recall) changes value
            i = np.where(mrec[1:] != mrec[:-1])[0]
    
            # and sum (Delta recall) * prec
            ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
        return ap
    
    def voc_eval(detpath,
                 annopath,
                 imagesetfile,
                 classname,
                 cachedir,
                 ovthresh=0.5,
                 use_07_metric=False):
        """rec, prec, ap = voc_eval(detpath,
                                    annopath,
                                    imagesetfile,
                                    classname,
                                    [ovthresh],
                                    [use_07_metric])
    
        Top level function that does the PASCAL VOC evaluation.
    
        detpath: Path to detections
            detpath.format(classname) should produce the detection results file.
        annopath: Path to annotations
            annopath.format(imagename) should be the xml annotations file.
        imagesetfile: Text file containing the list of images, one image per line.
        classname: Category name (duh)
        cachedir: Directory for caching the annotations
        [ovthresh]: Overlap threshold (default = 0.5)
        [use_07_metric]: Whether to use VOC07's 11 point AP computation
            (default False)
        """
        # assumes detections are in detpath.format(classname)
        # assumes annotations are in annopath.format(imagename)
        # assumes imagesetfile is a text file with each line an image name
        # cachedir caches the annotations in a pickle file
    
        # first load gt
        if not os.path.isdir(cachedir):
            os.mkdir(cachedir)
        cachefile = os.path.join(cachedir, 'annots.pkl')
        # read list of images
        with open(imagesetfile, 'r') as f:
            lines = f.readlines()
        imagenames = [x.strip().split(".jpg")[0] for x in lines]
        imagenames = [os.path.basename(x) for x in imagenames]
    
        if not os.path.isfile(cachefile):
            # load annots
            recs = {}
            for i, imagename in enumerate(imagenames):
                recs[imagename] = parse_rec(annopath.format(imagename))
                if i % 100 == 0:
                    print('Reading annotation for {:d}/{:d}'.format(
                        i + 1, len(imagenames)))
            # save
            print('Saving cached annotations to {:s}'.format(cachefile))
            with open(cachefile, 'wb') as f:
                pickle.dump(recs, f)
        else:
            # load
            with open(cachefile, 'rb',) as f:
                recs = pickle.load(f, encoding="utf-8")
    
        # extract gt objects for this class
        class_recs = {}
        npos = 0
        for imagename in imagenames:
            R = [obj for obj in recs[imagename] if obj['name'] in classname]
            bbox = np.array([x['bbox'] for x in R])
            difficult = np.array([x['difficult'] for x in R]).astype(np.bool)
            det = [False] * len(R)
            npos = npos + sum(~difficult)
            class_recs[imagename] = {'bbox': bbox,
                                     'difficult': difficult,
                                     'det': det}
        
        # read dets
        detfile = detpath.format(classname)
        with open(detfile, 'r') as f:
            lines = f.readlines()
    
        splitlines = [x.strip().split(' ') for x in lines]
        image_ids = [x[0] for x in splitlines]
        confidence = np.array([float(x[1]) for x in splitlines])
        BB = np.array([[float(z) for z in x[2:]] for x in splitlines])
         
        if image_ids:
            pass
        else:
            print('empty')
            return 0,0,0
        
    
        # sort by confidence
        sorted_ind = np.argsort(-confidence)
        sorted_scores = np.sort(-confidence)
        BB = BB[sorted_ind, :]
        image_ids = [image_ids[x] for x in sorted_ind]
    
        # go down dets and mark TPs and FPs
        nd = len(image_ids)
        tp = np.zeros(nd)
        fp = np.zeros(nd)
        for d in range(nd):
            R = class_recs[image_ids[d]]
            bb = BB[d, :].astype(float)
            ovmax = -np.inf
            BBGT = R['bbox'].astype(float)
    
            if BBGT.size > 0:
                # compute overlaps
                # intersection
                ixmin = np.maximum(BBGT[:, 0], bb[0])
                iymin = np.maximum(BBGT[:, 1], bb[1])
                ixmax = np.minimum(BBGT[:, 2], bb[2])
                iymax = np.minimum(BBGT[:, 3], bb[3])
                iw = np.maximum(ixmax - ixmin + 1., 0.)
                ih = np.maximum(iymax - iymin + 1., 0.)
                inters = iw * ih
    
                # union
                uni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.) +
                       (BBGT[:, 2] - BBGT[:, 0] + 1.) *
                       (BBGT[:, 3] - BBGT[:, 1] + 1.) - inters)
    
                overlaps = inters / uni
                ovmax = np.max(overlaps)
                jmax = np.argmax(overlaps)
    
            if ovmax > ovthresh:
                if not R['difficult'][jmax]:
                    if not R['det'][jmax]:
                        tp[d] = 1.
                        R['det'][jmax] = 1
                    else:
                        fp[d] = 1.
            else:
                fp[d] = 1.
    
        # compute precision recall
        fp = np.cumsum(fp)
        tp = np.cumsum(tp)
        rec = tp / float(npos)
        # avoid divide by zero in case the first detection matches a difficult
        # ground truth
        prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
        ap = voc_ap(rec, prec, use_07_metric)
    
        return rec, prec, ap

    执行 calc_map.py 脚本,计算map

    4、yolo系列,ssd系列,以及faster-rcnn系列目标检测网络模型,我都会用以上方法统计map

    一般情况下,得分阈值设为 0.001, nms阈值设为0.45, iou阈值设为0.5,

    注意:本代码截取的 py-faster-rcnn 官方代码修改实现,数据集的格式为 pascal voc 官方数据集格式。 

  • 相关阅读:
    2013国内IT行业薪资对照表【技术岗位】
    Eclipse查看子类
    whereis 查找命令全路径
    开张了
    Ruby1.8中单行字符串写在多行
    FEMTO是什么
    FUSE文件系统
    魔兽私服pvpgn搭建
    linux网络源码分析(1)
    freehosting申请空间和ssh D设置
  • 原文地址:https://www.cnblogs.com/hypnus-ly/p/12931371.html
Copyright © 2011-2022 走看看