  • mmdetection (1): Installation, and Training/Testing on VOC-Format Data

    I. Installation

             https://github.com/open-mmlab/mmdetection/blob/master/docs/INSTALL.md

    II. Training on your own data

     1. Data

         mmdet defaults to the COCO format; here we take the VOC format as an example. The folders under data/ follow the standard VOCdevkit layout.
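As a sketch of that layout, the following small script (illustrative only, not part of mmdetection; the `data` root matches the config below) checks that the expected sub-directories exist:

```python
import os

# Standard VOCdevkit layout expected by the config below:
#   data/VOCdevkit/VOC2007/Annotations/     - XML annotation files
#   data/VOCdevkit/VOC2007/JPEGImages/      - images
#   data/VOCdevkit/VOC2007/ImageSets/Main/  - trainval.txt / test.txt split files
REQUIRED_DIRS = [
    'VOCdevkit/VOC2007/Annotations',
    'VOCdevkit/VOC2007/JPEGImages',
    'VOCdevkit/VOC2007/ImageSets/Main',
]

def missing_dirs(data_root):
    """Return the required sub-directories that are absent under data_root."""
    return [d for d in REQUIRED_DIRS
            if not os.path.isdir(os.path.join(data_root, d))]

if __name__ == '__main__':
    for d in missing_dirs('data'):
        print('missing:', d)
```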

       2. Training

             (1) Modify a config file under configs/

               Copy an existing config and rename it, e.g. retinanet_x101_64x4d_fpn_1x.py. The main changes are in the dataset settings section, which can follow

             pascal_voc/faster_rcnn_r50_fpn_1x_voc0712.py (shown below). You also need to change num_classes in the same file to (number of classes + 1).

    # dataset settings
    dataset_type = 'VOCDataset'
    data_root = 'data/VOCdevkit/'
    img_norm_cfg = dict(
        mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
    train_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(type='LoadAnnotations', with_bbox=True),
        dict(type='Resize', img_scale=(1000, 600), keep_ratio=True),
        dict(type='RandomFlip', flip_ratio=0.5),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='Pad', size_divisor=32),
        dict(type='DefaultFormatBundle'),
        dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
    ]
    test_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(
            type='MultiScaleFlipAug',
            img_scale=(1000, 600),
            flip=False,
            transforms=[
                dict(type='Resize', keep_ratio=True),
                dict(type='RandomFlip'),
                dict(type='Normalize', **img_norm_cfg),
                dict(type='Pad', size_divisor=32),
                dict(type='ImageToTensor', keys=['img']),
                dict(type='Collect', keys=['img']),
            ])
    ]
    data = dict(
        imgs_per_gpu=2,
        workers_per_gpu=2,
        train=dict(
            type='RepeatDataset',
            times=3,
            dataset=dict(
                type=dataset_type,
                ann_file=[
                    data_root + 'VOC2007/ImageSets/Main/trainval.txt'
                ],
                img_prefix=[data_root + 'VOC2007/'],
                pipeline=train_pipeline)),
        val=dict(
            type=dataset_type,
            ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
            img_prefix=data_root + 'VOC2007/',
            pipeline=test_pipeline),
        test=dict(
            type=dataset_type,
            ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt',
            img_prefix=data_root + 'VOC2007/',
            pipeline=test_pipeline))
    evaluation = dict(interval=1, metric='mAP')
    # optimizer
    optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
    optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
    # learning policy
    lr_config = dict(policy='step', step=[3])  # actual epoch = 3 * 3 = 9
    checkpoint_config = dict(interval=1)
    # yapf:disable
    log_config = dict(
        interval=50,
        hooks=[
            dict(type='TextLoggerHook'),
            # dict(type='TensorboardLoggerHook')
        ])
    # yapf:enable
    # runtime settings
    total_epochs = 100  # with RepeatDataset(times=3), this is 100 * 3 = 300 passes over the data
    dist_params = dict(backend='nccl')
    log_level = 'INFO'
    work_dir = './work_dirs/faster_rcnn_r50_fpn_1x_voc0712'
    load_from = None
    resume_from = None
    workflow = [('train', 1)]
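The num_classes change mentioned above can be sketched as follows; in mmdet 1.x the count includes a background class (the value 3 here is a hypothetical example):

```python
# Fragment of the model settings in the copied config file.
# mmdet 1.x reserves a background class, so num_classes = your classes + 1.
NUM_CUSTOM_CLASSES = 3  # hypothetical: three custom classes

model = dict(
    bbox_head=dict(
        num_classes=NUM_CUSTOM_CLASSES + 1))  # 3 + 1 = 4
```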

              (2) Change CLASSES in mmdet/datasets/voc.py to your own class names
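With three hypothetical classes, the edit would look like this (the names here are placeholders; they must match the `<name>` tags in your XML annotations):

```python
# mmdet/datasets/voc.py (fragment): replace the 20 VOC class names
# with your own, in a fixed order.
CLASSES = ('cat', 'dog', 'rabbit')  # placeholder class names
```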

              (3) Train

          python tools/train.py configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc0712_my.py 
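One thing to watch when training on fewer GPUs: the lr=0.01 in the config above follows the upstream setup (assumed here to be 8 GPUs × 2 imgs_per_gpu), and the linear scaling rule is commonly used to adjust it. A sketch, not something tools/train.py does automatically:

```python
# Linear LR scaling rule: scale the learning rate with total batch size.
# Assumed baseline: 8 GPUs * 2 imgs_per_gpu = batch 16 at lr = 0.01.
def scaled_lr(gpus, imgs_per_gpu=2, base_lr=0.01, base_batch=16):
    return base_lr * (gpus * imgs_per_gpu) / base_batch

print(scaled_lr(gpus=1))  # single GPU -> 0.00125
```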

     3. Testing

          (1) Output mAP

         Modify voc_classes() under mmdetection/mmdet/core/evaluation (in class_names.py) to return your own class names.
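That change can look like the following sketch (placeholder names; keep the order consistent with the CLASSES tuple in voc.py):

```python
# mmdet/core/evaluation/class_names.py (fragment):
# voc_classes() supplies the names used when reporting per-class AP.
def voc_classes():
    return ['cat', 'dog', 'rabbit']  # placeholder class names
```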

            python3 tools/test.py configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc0712_my.py work_dirs/faster_rcnn_r50_fpn_1x_voc0712/latest.pth --eval mAP --show

        (2) Test a single image

            Following demo/webcam_demo.py, write an img_demo.py:

    python demo/img_demo.py configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc0712_my.py work_dirs/faster_rcnn_r50_fpn_1x_voc0712/latest.pth demo/2017-09-05-161908.jpg

            

    import argparse
    import torch
    
    from mmdet.apis import inference_detector, init_detector, show_result
    
    
    def parse_args():
        parser = argparse.ArgumentParser(description='MMDetection image demo')
        parser.add_argument('config', help='test config file path')
        parser.add_argument('checkpoint', help='checkpoint file')
        parser.add_argument('imagepath', help='the path of image to test')
        parser.add_argument('--device', type=int, default=0, help='CUDA device id')
        parser.add_argument(
            '--score-thr', type=float, default=0.5, help='bbox score threshold')
        args = parser.parse_args()
        return args
    
    
    def main():
        args = parse_args()

        # build the detector from the config and load the trained weights
        model = init_detector(
            args.config, args.checkpoint, device=torch.device('cuda', args.device))

        # run inference on one image; result is a list of per-class bbox arrays
        result = inference_detector(model, args.imagepath)

        # draw boxes above the score threshold and display the image
        show_result(
            args.imagepath, result, model.CLASSES, score_thr=args.score_thr, wait_time=0)
    
    
    if __name__ == '__main__':
        main()

    References

    https://zhuanlan.zhihu.com/p/101202864

    https://blog.csdn.net/laizi_laizi/article/details/104256781

  • Original post: https://www.cnblogs.com/573177885qq/p/12734630.html