zoukankan      html  css  js  c++  java
  • Detectron2训练visdrone记录

    准备

    VOC标签转换参见这篇
    注意:object_name = name_dict[box[4]] 改为 object_name = name_dict[box[5]]。为了与detectron2统一,
    标签文件夹命名为Annotations,图片文件夹命名为JPEGImages,train.txt位于xxx/ImageSets/Main/。

    train

    构建instance

    # -*- coding: utf-8 -*-
    # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
    
    import numpy as np
    import os
    import xml.etree.ElementTree as ET
    from fvcore.common.file_io import PathManager
    
    from detectron2.data import DatasetCatalog, MetadataCatalog
    from detectron2.structures import BoxMode
    
    __all__ = ["register_visdrone_voc"]
    
    # VisDrone-DET category names. A name's position in this list is used as
    # its category id, so slot 0 is occupied by a '__background__' placeholder.
    CLASS_NAMES = [
        '__background__',
        'pedestrian',
        'people',
        'bicycle',
        'car',
        'van',
        'truck',
        'tricycle',
        'awning-tricycle',
        'bus',
        'motor',
    ]
    
    
    def load_voc_instances(dirname: str, split: str):
        """
        Load Pascal VOC-style detection annotations into Detectron2's dataset-dict format.

        Args:
            dirname: Dataset root containing "Annotations", "ImageSets", "JPEGImages".
            split (str): Name of the split file (without ".txt") under
                "ImageSets/Main", e.g. "train", "test", "val", "trainval".

        Returns:
            list[dict]: One dict per image id listed in the split file, with keys
            "file_name", "image_id", "height", "width" and "annotations". Each
            annotation holds "category_id" (index of the class name in
            CLASS_NAMES), "bbox" (four floats) and "bbox_mode" (BoxMode.XYXY_ABS).

        Raises:
            ValueError: If an object's class name is not present in CLASS_NAMES.
        """
        with PathManager.open(os.path.join(dirname, "ImageSets", "Main", split + ".txt")) as f:
            # The `np.str` alias was removed in NumPy 1.24; builtin `str` is the same dtype.
            # `ndmin=1` stops a one-line split file from collapsing into a 0-d array,
            # which could not be iterated below.
            fileids = np.loadtxt(f, dtype=str, ndmin=1)

        # Needs to read many small annotation files. Makes sense at local
        annotation_dirname = PathManager.get_local_path(os.path.join(dirname, "Annotations/"))
        dicts = []
        for fileid in fileids:
            anno_file = os.path.join(annotation_dirname, fileid + ".xml")
            jpeg_file = os.path.join(dirname, "JPEGImages", fileid + ".jpg")

            with PathManager.open(anno_file) as f:
                tree = ET.parse(f)

            r = {
                "file_name": jpeg_file,
                "image_id": fileid,
                "height": int(tree.findall("./size/height")[0].text),
                "width": int(tree.findall("./size/width")[0].text),
            }
            instances = []

            for obj in tree.findall("object"):
                cls = obj.find("name").text
                # "difficult" samples are deliberately kept: the <difficult> tag
                # is never read, so every <object> becomes a training instance.
                bbox = obj.find("bndbox")
                bbox = [float(bbox.find(x).text) for x in ["xmin", "ymin", "xmax", "ymax"]]
                # Original annotations are integers in the range [1, W or H]
                # Assuming they mean 1-based pixel indices (inclusive),
                # a box with annotation (xmin=1, xmax=W) covers the whole image.
                # In coordinate space this is represented by (xmin=0, xmax=W)
                bbox[0] -= 1.0
                bbox[1] -= 1.0
                instances.append(
                    {"category_id": CLASS_NAMES.index(cls), "bbox": bbox, "bbox_mode": BoxMode.XYXY_ABS}
                )
            r["annotations"] = instances
            dicts.append(r)
        return dicts
    
    
    def register_visdrone_voc(name, dirname, split, year):
        """
        Register a VOC-layout VisDrone split with Detectron2's catalogs.

        Args:
            name: Key under which the dataset is registered.
            dirname: Dataset root directory, forwarded to load_voc_instances.
            split: Split name, forwarded to load_voc_instances.
            year: Stored in the dataset metadata as-is.
        """
        def _load_split():
            # Deferred loader: annotations are only parsed when the catalog calls it.
            return load_voc_instances(dirname, split)

        DatasetCatalog.register(name, _load_split)

        metadata = MetadataCatalog.get(name)
        metadata.set(thing_classes=CLASS_NAMES, dirname=dirname, year=year, split=split)
    

    train采用Faster R-CNN with FPN,backbone使用ResNeXt-101,群卷积32x8d,即32个group,每个group8个filter。
    注意:如果因image图片损坏无法训练,修改PIL库的ImageFile.py,将LOAD_TRUNCATED_IMAGES 改为 True(也可以不改库文件,直接在训练脚本开头设置 ImageFile.LOAD_TRUNCATED_IMAGES = True)。

    from detectron2.engine import DefaultTrainer
    from detectron2.config import get_cfg
    from detectron2.checkpoint import DetectionCheckpointer
    from detectron2.engine import DefaultPredictor
    from detectron2.data import DatasetCatalog, MetadataCatalog
    from detectron2.evaluation import COCOEvaluator, PascalVOCDetectionEvaluator, inference_on_dataset
    from detectron2.data import build_detection_test_loader
    from visdrone_voc import *
    import os
    import cv2
    import torch
    
    # Register the three VisDrone splits (converted to VOC layout) with Detectron2.
    # The final argument (2012) is stored as the `year` metadata field.
    register_visdrone_voc('VISDRONE_VOC', os.path.join('/home/chenzhengxi/data/VisDrone/VisDrone2018-DET-train'),
                          'train', 2012)
    register_visdrone_voc('VISDRONE_VAL', os.path.join('/home/chenzhengxi/data/VisDrone/VisDrone2018-DET-val'),
                          'val', 2012)
    register_visdrone_voc('VISDRONE_TEST', os.path.join('/home/chenzhengxi/data/VisDrone/VisDrone2019-DET-test-dev'),
                          'test', 2012)
    # Start from Detectron2's defaults, then overlay the model/solver YAML.
    # NOTE(review): cfg.DATASETS.TRAIN / cfg.DATASETS.TEST are never assigned in
    # this script; presumably the YAML names the datasets registered above - confirm.
    cfg = get_cfg()
    cfg.merge_from_file('configs/faster_rcnn_X_101_32x8d_FPN_3x.yaml')
    
    # cfg.DATASETS.TEST = ()
    cfg.DATALOADER.NUM_WORKERS = 4
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128
    
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    trainer = DefaultTrainer(cfg)
    # resume=True continues training and loads the latest checkpoint
    trainer.resume_or_load(resume=False)
    trainer.train()
    
    # The code below can be used to load a specific weights file instead
    #cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_0229999.pth")
    #checkpointer = DetectionCheckpointer(trainer.model)
    #checkpointer.load(cfg.MODEL.WEIGHTS)
    # NOTE(review): this is assigned after DefaultTrainer(cfg) was constructed;
    # verify that it actually affects trainer.model used in the evaluation below.
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7   # set the testing threshold for this model
    
    # Evaluate on the first dataset listed in cfg.DATASETS.TEST
    # (this indexing fails if that tuple is empty).
    evaluator = PascalVOCDetectionEvaluator(cfg.DATASETS.TEST[0])
    # val_loader = build_detection_test_loader(cfg, "VISDRONE_VAL")
    # result_val = inference_on_dataset(trainer.model, val_loader, evaluator)
    # print(result_val)
    print(trainer.test(cfg, trainer.model, evaluator))
    
    # Optional single-image inference and visualisation, kept for reference:
    # predictor = DefaultPredictor(cfg)
    # im = cv2.imread('/home/chenzhengxi/data/VisDrone/VisDrone2018-DET-val/JPEGImages/0000026_03500_d_0000031.jpg')
    # outputs = predictor(im)
    # ooo = outputs['instances'].to(torch.device("cpu"))
    # boxes = ooo.pred_boxes.tensor.numpy()
    # print(boxes)
    # for i in range(len(boxes)):
    #     cv2.rectangle(im, tuple(boxes[i, 0:2]), tuple(boxes[i, 2:4]), (0, 255, 0), 2)
    #
    # cv2.imshow('visdrone', im)
    # cv2.waitKey(0)
    

    测试记录

    注意:数据集中没有__background__类别的标注,计算AP时会除0,需修改pascal_voc_evaluation.py第90行附近,跳过该类别:

                for cls_id, cls_name in enumerate(self._class_names):
    +++             if cls_id == 0:  # __background__
    +++                 continue
                    lines = predictions.get(cls_id, [""])
    
    AP AP50 AP75 iter datasets
    20.8397 39.0423 19.9213 94999 val
    17.0480 32.8580 16.1755 94999 test
    22.2992 40.0259 21.7078 169999 val
    18.0168 33.5292 17.6131 169999 test
    22.9258 41.0643 22.4904 214999 val
    18.1249 33.7228 17.6461 214999 test
    22.8556 40.9861 22.3667 269999 val
    18.0256 33.5866 17.5259 269999 test

    可以看出效果远高于yolo,最终配置和权重下载,提取码: 74s4

  • 相关阅读:
    【CodeForces】[659C]Tanya and Toys
    【CodeForces】[659A]Round House
    高并发网络编程之epoll详解
    Linux写时拷贝技术(copy-on-write)
    5种服务器网络编程模型讲解
    5种服务器网络编程模型讲解
    当你输入一个网址的时候,实际会发生什么?
    error: std::ios_base::ios_base(const std::ios_base&)’是私有的
    C++和JAVA的区别
    为什么内联函数,构造函数,静态成员函数不能为virtual函数
  • 原文地址:https://www.cnblogs.com/chenzhengxi/p/13065792.html
Copyright © 2011-2022 走看看