  • Notes on training Detectron2 on VisDrone

    Preparation

    For converting the VisDrone annotations to VOC format, see this earlier post.
    Note: change object_name = name_dict[box[4]] to object_name = name_dict[box[5]]. To stay consistent with detectron2,
    name the annotation folder Annotations, the image folder JPEGImages, and place train.txt under xxx/ImageSets/Main/.
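
    For reference, the loader in the next section expects a standard VOC-style layout like this (the dataset root name is just an example; the folder and file names match what the code reads):

    VisDrone2018-DET-train/
    ├── Annotations/          # one Pascal VOC .xml file per image
    ├── JPEGImages/           # the .jpg images
    └── ImageSets/
        └── Main/
            └── train.txt     # image IDs (file names without extension), one per line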

    Training

    Building the dataset instances

    # -*- coding: utf-8 -*-
    # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
    
    import numpy as np
    import os
    import xml.etree.ElementTree as ET
    from fvcore.common.file_io import PathManager
    
    from detectron2.data import DatasetCatalog, MetadataCatalog
    from detectron2.structures import BoxMode
    
    __all__ = ["register_visdrone_voc"]
    
    CLASS_NAMES = ['__background__',  # always index 0
                   'pedestrian', 'people', 'bicycle', 'car', 'van', 'truck', 'tricycle', 'awning-tricycle', 'bus', 'motor']
    
    
    def load_voc_instances(dirname: str, split: str):
        """
        Load Pascal VOC detection annotations to Detectron2 format.
    
        Args:
        dirname: Contains "Annotations", "ImageSets", "JPEGImages"
            split (str): one of "train", "test", "val", "trainval"
        """
        with PathManager.open(os.path.join(dirname, "ImageSets", "Main", split + ".txt")) as f:
            fileids = np.loadtxt(f, dtype=str)  # np.str was removed in newer NumPy; plain str behaves the same
    
        # Needs to read many small annotation files; fetch a local copy of the directory first.
        annotation_dirname = PathManager.get_local_path(os.path.join(dirname, "Annotations/"))
        dicts = []
        for fileid in fileids:
            anno_file = os.path.join(annotation_dirname, fileid + ".xml")
            jpeg_file = os.path.join(dirname, "JPEGImages", fileid + ".jpg")
    
            with PathManager.open(anno_file) as f:
                tree = ET.parse(f)
    
            r = {
                "file_name": jpeg_file,
                "image_id": fileid,
                "height": int(tree.findall("./size/height")[0].text),
                "width": int(tree.findall("./size/width")[0].text),
            }
            instances = []
    
            for obj in tree.findall("object"):
                cls = obj.find("name").text
                # We include "difficult" samples in training.
                # Based on limited experiments, they don't hurt accuracy.
                # difficult = int(obj.find("difficult").text)
                # if difficult == 1:
                # continue
                bbox = obj.find("bndbox")
                bbox = [float(bbox.find(x).text) for x in ["xmin", "ymin", "xmax", "ymax"]]
                # Original annotations are integers in the range [1, W or H]
                # Assuming they mean 1-based pixel indices (inclusive),
                # a box with annotation (xmin=1, xmax=W) covers the whole image.
                # In coordinate space this is represented by (xmin=0, xmax=W)
                bbox[0] -= 1.0
                bbox[1] -= 1.0
                instances.append(
                    {"category_id": CLASS_NAMES.index(cls), "bbox": bbox, "bbox_mode": BoxMode.XYXY_ABS}
                )
            r["annotations"] = instances
            dicts.append(r)
        return dicts
    
    
    def register_visdrone_voc(name, dirname, split, year):
        DatasetCatalog.register(name, lambda: load_voc_instances(dirname, split))
        MetadataCatalog.get(name).set(
            thing_classes=CLASS_NAMES, dirname=dirname, year=year, split=split
        )
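
    Before wiring the loader into the training script, it helps to sanity check it once. A minimal sketch, assuming the file above is saved as visdrone_voc.py; the dataset name VISDRONE_VOC_CHECK is just an illustrative label and the path (taken from the training script below) should be adjusted to your machine:

    # Quick sanity check for the loader above; the dataset name and path are illustrative.
    from detectron2.data import DatasetCatalog
    from visdrone_voc import register_visdrone_voc

    register_visdrone_voc('VISDRONE_VOC_CHECK',
                          '/home/chenzhengxi/data/VisDrone/VisDrone2018-DET-train',
                          'train', 2012)
    dicts = DatasetCatalog.get('VISDRONE_VOC_CHECK')  # invokes load_voc_instances
    print(len(dicts), 'images loaded')
    print(dicts[0]['file_name'], '->', len(dicts[0]['annotations']), 'boxes')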
    

    Training uses Faster R-CNN with FPN and a ResNeXt-101 backbone with 32x8d grouped convolutions (32 groups, 8 filters per group).
    Note: if training fails because of corrupted image files, edit ImageFile.py in the PIL library and set LOAD_TRUNCATED_IMAGES to True.
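
    Instead of editing the installed PIL package, the same flag can also be set at runtime at the top of the training script; a minimal sketch:

    # Equivalent to editing ImageFile.py: tell PIL to tolerate truncated/corrupted image files.
    from PIL import ImageFile
    ImageFile.LOAD_TRUNCATED_IMAGES = True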

    from detectron2.engine import DefaultTrainer
    from detectron2.config import get_cfg
    from detectron2.checkpoint import DetectionCheckpointer
    from detectron2.engine import DefaultPredictor
    from detectron2.data import DatasetCatalog, MetadataCatalog
    from detectron2.evaluation import COCOEvaluator, PascalVOCDetectionEvaluator, inference_on_dataset
    from detectron2.data import build_detection_test_loader
    from visdrone_voc import *
    import os
    import cv2
    import torch
    
    register_visdrone_voc('VISDRONE_VOC', os.path.join('/home/chenzhengxi/data/VisDrone/VisDrone2018-DET-train'),
                          'train', 2012)
    register_visdrone_voc('VISDRONE_VAL', os.path.join('/home/chenzhengxi/data/VisDrone/VisDrone2018-DET-val'),
                          'val', 2012)
    register_visdrone_voc('VISDRONE_TEST', os.path.join('/home/chenzhengxi/data/VisDrone/VisDrone2019-DET-test-dev'),
                          'test', 2012)
    cfg = get_cfg()
    cfg.merge_from_file('configs/faster_rcnn_X_101_32x8d_FPN_3x.yaml')
    
    # cfg.DATASETS.TEST = ()
    cfg.DATALOADER.NUM_WORKERS = 4
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128
    
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    trainer = DefaultTrainer(cfg)
    # resume=True resumes training and loads the latest checkpoint
    trainer.resume_or_load(resume=False)
    trainer.train()
    
    # The following lines (uncomment to use) load a specific checkpoint instead
    #cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_0229999.pth")
    #checkpointer = DetectionCheckpointer(trainer.model)
    #checkpointer.load(cfg.MODEL.WEIGHTS)
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7   # set the testing threshold for this model
    
    evaluator = PascalVOCDetectionEvaluator(cfg.DATASETS.TEST[0])
    # val_loader = build_detection_test_loader(cfg, "VISDRONE_VAL")
    # result_val = inference_on_dataset(trainer.model, val_loader, evaluator)
    # print(result_val)
    print(trainer.test(cfg, trainer.model, evaluator))
    
    # predictor = DefaultPredictor(cfg)
    # im = cv2.imread('/home/chenzhengxi/data/VisDrone/VisDrone2018-DET-val/JPEGImages/0000026_03500_d_0000031.jpg')
    # outputs = predictor(im)
    # ooo = outputs['instances'].to(torch.device("cpu"))
    # boxes = ooo.pred_boxes.tensor.numpy()
    # print(boxes)
    # for i in range(len(boxes)):
    #     cv2.rectangle(im, tuple(boxes[i, 0:2]), tuple(boxes[i, 2:4]), (0, 255, 0), 2)
    #
    # cv2.imshow('visdrone', im)
    # cv2.waitKey(0)
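
    If you enable the commented-out inference snippet above, keep in mind that cv2.rectangle wants integer pixel coordinates. An alternative is detectron2's Visualizer, which draws boxes together with class labels and scores. A minimal sketch, reusing the cfg, imports, and registered metadata already set up in this script (the image path is the same example as above):

    # Sketch: render predictions with detectron2's Visualizer instead of raw cv2.rectangle calls.
    from detectron2.utils.visualizer import Visualizer

    predictor = DefaultPredictor(cfg)
    im = cv2.imread('/home/chenzhengxi/data/VisDrone/VisDrone2018-DET-val/JPEGImages/0000026_03500_d_0000031.jpg')
    outputs = predictor(im)
    v = Visualizer(im[:, :, ::-1], MetadataCatalog.get('VISDRONE_VAL'))   # BGR -> RGB
    vis = v.draw_instance_predictions(outputs['instances'].to(torch.device('cpu')))
    cv2.imshow('visdrone', vis.get_image()[:, :, ::-1])                   # RGB -> BGR
    cv2.waitKey(0)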
    

    Test results

    Note: there are no __background__ detections in the dataset, so the AP computation divides by zero. Patch pascal_voc_evaluation.py around line 90 to skip that class:

                for cls_id, cls_name in enumerate(self._class_names):
    +++             if cls_id == 0:  # __background__
    +++                 continue
                    lines = predictions.get(cls_id, [""])
    
    AP       AP50     AP75     iter     dataset
    20.8397  39.0423  19.9213  94999    val
    17.0480  32.8580  16.1755  94999    test
    22.2992  40.0259  21.7078  169999   val
    18.0168  33.5292  17.6131  169999   test
    22.9258  41.0643  22.4904  214999   val
    18.1249  33.7228  17.6461  214999   test
    22.8556  40.9861  22.3667  269999   val
    18.0256  33.5866  17.5259  269999   test

    As the numbers show, this performs far better than YOLO. The final config and weights can be downloaded (extraction code: 74s4).

  • Original article: https://www.cnblogs.com/chenzhengxi/p/13065792.html