zoukankan      html  css  js  c++  java
  • yolov3计算mAP

    搜了半天,都是要下载voc_eval.py文件,但该文件只能在python2下运行。以下是python3版本的:

    # coding:utf-8
    import xml.etree.ElementTree as ET
    import os
    #import cPickle
    import _pickle as cPickle
    import numpy as np
     
    def parse_rec(filename):
        """Parse a PASCAL VOC XML annotation file.

        Args:
            filename: path to the ``.xml`` annotation file.

        Returns:
            A list with one dict per ``<object>`` element, with keys
            ``'name'``, ``'pose'``, ``'truncated'``, ``'difficult'`` and
            ``'bbox'`` (``[xmin, ymin, xmax, ymax]`` as ints).

        Optional tags that some labelling tools omit are tolerated: a missing
        ``<pose>`` yields ``None`` and a missing ``<truncated>`` /
        ``<difficult>`` yields ``0``. Coordinates written as floats
        (e.g. ``"48.0"``) are accepted and truncated to int.
        """
        def _text(node, tag, default=None):
            # Text of the child <tag>, stripped; *default* if absent or empty.
            child = node.find(tag)
            if child is None or child.text is None:
                return default
            return child.text.strip()

        tree = ET.parse(filename)
        objects = []
        for obj in tree.findall('object'):
            bbox = obj.find('bndbox')
            objects.append({
                'name': _text(obj, 'name'),
                'pose': _text(obj, 'pose'),
                'truncated': int(_text(obj, 'truncated', '0')),
                'difficult': int(_text(obj, 'difficult', '0')),
                # int(float(...)) accepts both "48" and "48.0".
                'bbox': [int(float(_text(bbox, tag)))
                         for tag in ('xmin', 'ymin', 'xmax', 'ymax')],
            })
        return objects
     
    def voc_ap(rec, prec, use_07_metric=False):
        """Compute PASCAL VOC average precision from a precision/recall curve.

        Args:
            rec: numpy array of cumulative recall values.
            prec: numpy array of precision values, aligned with *rec*.
            use_07_metric: if True, use the VOC2007 11-point interpolation
                (recall thresholds 0.0, 0.1, ..., 1.0); otherwise compute the
                exact area under the monotone precision envelope.

        Returns:
            The average precision as a float.
        """
        if not use_07_metric:
            # Pad with sentinels so the curve starts at recall 0 and ends at 1.
            recall_pts = np.concatenate(([0.], rec, [1.]))
            precision_pts = np.concatenate(([0.], prec, [0.]))

            # Running max from the right: precision becomes non-increasing.
            precision_pts = np.maximum.accumulate(precision_pts[::-1])[::-1]

            # Only indices where recall actually moves contribute area.
            steps = np.where(recall_pts[1:] != recall_pts[:-1])[0]
            widths = recall_pts[steps + 1] - recall_pts[steps]
            return np.sum(widths * precision_pts[steps + 1])

        # VOC2007: mean of the best precision at 11 recall thresholds.
        total = 0.
        for threshold in np.arange(0., 1.1, 0.1):
            reached = rec >= threshold
            best = np.max(prec[reached]) if np.sum(reached) != 0 else 0
            total = total + best / 11.
        return total
     
    def _load_gt_annotations(imagenames, annopath, cachefile):
        """Return ``{imagename: parse_rec(annopath.format(imagename))}``.

        Parsed annotations are cached in *cachefile* with pickle. The cache is
        reused only when it loads cleanly and contains every name in
        *imagenames*. Otherwise the XML files are parsed again and the cache is
        rewritten. The cache only needs deleting by hand when the XML contents
        change while the image list stays the same.
        """
        recs = None
        if os.path.isfile(cachefile):
            # NOTE(security): unpickling a crafted file can execute code. This
            # is acceptable only because this module writes the cache itself.
            try:
                with open(cachefile, 'rb') as f:
                    recs = cPickle.load(f)
            except (EOFError, cPickle.UnpicklingError):
                recs = None  # truncated / corrupt cache -> rebuild below
            if not isinstance(recs, dict) or any(name not in recs for name in imagenames):
                recs = None  # cache built for a different image list -> rebuild
        if recs is None:
            recs = {}
            for i, imagename in enumerate(imagenames):
                recs[imagename] = parse_rec(annopath.format(imagename))
                if i % 100 == 0:
                    print('Reading annotation for {:d}/{:d}'.format(
                        i + 1, len(imagenames)))
            print('Saving cached annotations to {:s}'.format(cachefile))
            with open(cachefile, 'wb') as f:
                cPickle.dump(recs, f)
        return recs

    def _read_detections(detfile):
        """Read a per-class detection file.

        Each non-blank line is ``<image_id> <score> <xmin> <ymin> <xmax> <ymax>``.
        Returns ``(image_ids, confidence, BB)``: a list of N ids, a float
        array of shape (N,), and a float array of shape (N, 4). N may be 0.
        """
        # Text mode: ids come back as plain str, with no b'...' wrapper to strip.
        with open(detfile, 'r') as f:
            splitlines = [line.split() for line in f if line.strip()]
        image_ids = [x[0] for x in splitlines]
        confidence = np.array([float(x[1]) for x in splitlines], dtype=float)
        # reshape keeps the (N, 4) shape even when the file has no detections.
        BB = np.array([[float(z) for z in x[2:6]] for x in splitlines],
                      dtype=float).reshape(-1, 4)
        return image_ids, confidence, BB

    def voc_eval(detpath,
                 annopath,
                 imagesetfile,
                 classname,
                 cachedir,
                 ovthresh=0.5,
                 use_07_metric=False):
        """PASCAL VOC evaluation of one class.

        rec, prec, ap = voc_eval(detpath, annopath, imagesetfile, classname,
                                 cachedir, [ovthresh], [use_07_metric])

        Args:
            detpath: detection file pattern; ``detpath.format(classname)`` is
                the detection results file for this class.
            annopath: annotation pattern; ``annopath.format(imagename)`` is the
                XML annotation file of an image.
            imagesetfile: text file listing the image names, one per line
                (blank lines are ignored).
            classname: the category to evaluate.
            cachedir: directory where parsed annotations are cached as
                ``annots.pkl``. It is created if missing.
            ovthresh: IoU threshold for a detection to match a ground truth
                (strictly greater than; default 0.5).
            use_07_metric: use the VOC2007 11-point AP (default False).

        Returns:
            ``(rec, prec, ap)``: the cumulative recall and precision arrays,
            in descending confidence order, and the average precision.

        Raises:
            KeyError: a detection refers to an image not listed in
                *imagesetfile*.
        """
        os.makedirs(cachedir, exist_ok=True)
        cachefile = os.path.join(cachedir, 'annots.pkl')

        # Read the list of image names (file names without directory or extension).
        with open(imagesetfile, 'r') as f:
            imagenames = [x.strip() for x in f if x.strip()]

        recs = _load_gt_annotations(imagenames, annopath, cachefile)

        # Collect the ground-truth objects of this class, per image.
        class_recs = {}
        npos = 0  # number of non-difficult ground-truth boxes
        for imagename in imagenames:
            R = [obj for obj in recs[imagename] if obj['name'] == classname]
            bbox = np.array([x['bbox'] for x in R], dtype=float).reshape(-1, 4)
            difficult = np.array([x['difficult'] for x in R], dtype=bool)
            npos += int(np.sum(~difficult))
            class_recs[imagename] = {'bbox': bbox,
                                     'difficult': difficult,
                                     'det': [False] * len(R)}

        image_ids, confidence, BB = _read_detections(detpath.format(classname))

        # Process detections in descending confidence order.
        sorted_ind = np.argsort(-confidence)
        BB = BB[sorted_ind, :]
        image_ids = [image_ids[x] for x in sorted_ind]

        # Walk the detections and mark each one as a TP or an FP.
        nd = len(image_ids)
        tp = np.zeros(nd)
        fp = np.zeros(nd)
        for d in range(nd):
            R = class_recs.get(image_ids[d])
            if R is None:
                raise KeyError('detection for image {!r} which is not listed in {}'
                               .format(image_ids[d], imagesetfile))
            bb = BB[d, :]
            ovmax = -np.inf
            BBGT = R['bbox']

            if BBGT.size > 0:
                # IoU with every GT box (VOC convention: inclusive pixel coords, +1).
                ixmin = np.maximum(BBGT[:, 0], bb[0])
                iymin = np.maximum(BBGT[:, 1], bb[1])
                ixmax = np.minimum(BBGT[:, 2], bb[2])
                iymax = np.minimum(BBGT[:, 3], bb[3])
                iw = np.maximum(ixmax - ixmin + 1., 0.)
                ih = np.maximum(iymax - iymin + 1., 0.)
                inters = iw * ih

                uni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.) +
                       (BBGT[:, 2] - BBGT[:, 0] + 1.) *
                       (BBGT[:, 3] - BBGT[:, 1] + 1.) - inters)

                overlaps = inters / uni
                ovmax = np.max(overlaps)
                jmax = np.argmax(overlaps)

            if ovmax > ovthresh:
                # A match to a "difficult" GT counts as neither TP nor FP.
                if not R['difficult'][jmax]:
                    if not R['det'][jmax]:
                        tp[d] = 1.
                        R['det'][jmax] = True
                    else:
                        fp[d] = 1.  # duplicate detection of an already-matched GT
            else:
                fp[d] = 1.  # no GT overlaps enough

        # Compute cumulative precision and recall.
        fp = np.cumsum(fp)
        tp = np.cumsum(tp)
        # A class with no non-difficult GT cannot be recalled; report zero recall
        # rather than dividing by zero.
        rec = tp / float(npos) if npos > 0 else np.zeros_like(tp)
        # Avoid division by zero when the first detections match difficult GTs.
        prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
        ap = voc_ap(rec, prec, use_07_metric)

        return rec, prec, ap
    

    将以上代码保存为voc_eval.py并放在yolov3根目录下,然后编写计算脚本computer_single_all_map.py:

    from voc_eval import voc_eval
    
    import os
    
    # The path strings below must be adapted to your dataset layout.
    detpath = './results/{}.txt'              # per-class detection file; {} -> class name
    annopath = './xml/{}.xml'                 # VOC XML annotation; {} -> image name
    imagesetfile = './data/pig_val_map.txt'   # one image name per line, no dir/extension
    cachedir = '.'                            # where annots.pkl is written

    # Each "<class>.txt" in the detection directory is one class to evaluate.
    # The directory is derived from detpath so the two cannot drift apart.
    # Other files there (e.g. hidden files) are ignored.
    results_path = os.path.dirname(detpath) or '.'
    class_names = sorted(
        os.path.splitext(name)[0]
        for name in os.listdir(results_path)
        if name.endswith('.txt')
    )

    aps = []
    for class_name in class_names:
        rec, prec, ap = voc_eval(detpath, annopath, imagesetfile, class_name, cachedir)
        print("{} :\t {} ".format(class_name, ap))
        aps.append(ap)

    print("***************************")
    if aps:
        print("mAP :\t {}".format(float(sum(aps) / len(aps))))
    else:
        # Avoid ZeroDivisionError when darknet produced no detection files.
        print("No detection files (*.txt) found in {}".format(results_path))
    

    需要改3个位置:

    ./results/{}.txt改为darknet valid生成的检测结果文件的路径({}会被替换为类别名),如图;

    ./xml/{}.xml改为数据集xml标注文件的路径({}会被替换为图片名),如图;

    ./data/pig_val_map.txt改为验证集图片列表文件的路径(该文件每行一个图片名称,只写文件名,不要加路径和后缀),如图;

    生成验证文件
    ./darknet detector valid ./cfg/pig.data ./cfg/yolov3-pig.cfg backup/yolov3-pig_last.weights -out "" -gpu 0 -thresh .5

    会在results目录下为每个类别生成一个检测结果文件(文件名即类别名)。我只有一个分类pig,所以生成了pig.txt。

    计算mAP
    rm -f annots.pkl
    python computer_single_all_map.py
    每次重新计算,都要删掉annots.pkl文件。

    参考:YOLOv3 训练自己的数据附优化与问题总结 (mamicode.com)

    [yolov3]./darknet detector valid “eval: Using default ‘voc‘ 4 段错误“_yuchen的博客-CSDN博客

    YOLOv3 mAP计算教程_阿木寺的博客-CSDN博客_yolov3计算map

    YOLO V3计算mAP (voc_eval.py)适用于python3_菜鸟的进阶之路1D的博客-CSDN博客

  • 相关阅读:
    C#之枚举
    C#之判断字母大小、字母转ACII码
    C#之BF算法
    md5如何实现encodePassword加密方法
    基本配置及安全级别security-level
    js中“原生”map
    web.xml讲解
    java application指的是什么
    .conf、.bak是什么格式
    Maven系列--web.xml 配置详解
  • 原文地址:https://www.cnblogs.com/codeit/p/15748565.html
Copyright © 2011-2022 走看看