zoukankan      html  css  js  c++  java
  • Detectron2 训练+测试 代码框架

    1. run_predict.py

    import torch, torchvision
    import detectron2
    from detectron2.utils.logger import setup_logger
    # setup_logger must be *called* to install detectron2's log handler;
    # merely referencing the function object does nothing.
    setup_logger()

    import numpy as np
    import os, json, cv2, random
    import matplotlib.pyplot as plt

    from detectron2 import model_zoo
    from detectron2.engine import DefaultPredictor
    from detectron2.config import get_cfg
    from detectron2.utils.visualizer import Visualizer
    from detectron2.data import MetadataCatalog, DatasetCatalog


    # Load the test image. cv2.imread returns a BGR array, or None (no exception)
    # when the file is missing or unreadable, so check explicitly.
    im = cv2.imread('./input.jpg')
    if im is None:
        raise FileNotFoundError("could not read image './input.jpg'")
    # cv2_imshow(im)

    # Build the Mask R-CNN R50-FPN 3x configuration from the local configs directory.
    cfg = get_cfg()
    # Alternative, resolving the config through detectron2.model_zoo:
    # cfg.merge_from_file(model_zoo.get_config_file('COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml'))
    cfg.merge_from_file('../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml')
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # minimum detection score kept at test time
    # Alternative, resolving the weights through detectron2.model_zoo:
    # cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url('COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml')
    cfg.MODEL.WEIGHTS = 'model_final_f10217.pkl'  # local checkpoint file
    predictor = DefaultPredictor(cfg)
    outputs = predictor(im)

    # print(outputs['instances'].pred_classes)

    # Visualizer takes an RGB image, so flip OpenCV's BGR channel order.
    v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
    v = v.draw_instance_predictions(outputs['instances'].to('cpu'))
    plt.figure(figsize=(14, 10))
    # get_image() is already RGB, which is what matplotlib expects; no conversion needed.
    plt.imshow(v.get_image())
    plt.savefig('./output.jpg')
    plt.close()  # release the figure's memory
    

      

    2. run_train.py

    import torch, torchvision
    import detectron2
    from detectron2.utils.logger import setup_logger
    # setup_logger must be *called* so detectron2's log messages are emitted;
    # a bare reference to the function is a no-op.
    setup_logger()

    import numpy as np
    import os, json, cv2, random
    import matplotlib.pyplot as plt

    from detectron2 import model_zoo
    from detectron2.engine import DefaultPredictor
    from detectron2.config import get_cfg
    from detectron2.utils.visualizer import Visualizer
    from detectron2.data import MetadataCatalog, DatasetCatalog

    from detectron2.structures import BoxMode


    # if your dataset is in COCO format, this cell can be replaced by the following three lines:
    # from detectron2.data.datasets import register_coco_instances
    # register_coco_instances("my_dataset_train", {}, "json_annotation_train.json", "path/to/image/dir")
    # register_coco_instances("my_dataset_val", {}, "json_annotation_val.json", "path/to/image/dir")
    
    
    ################### READ Data ####################
    def get_balloon_dicts(img_dir):
        """Load the balloon annotations in ``img_dir`` as Detectron2 dataset dicts.

        Reads ``via_region_data.json`` from ``img_dir``. Each entry is expected
        to carry a ``filename`` and a ``regions`` mapping whose values hold
        ``shape_attributes`` with ``all_points_x``/``all_points_y`` polygon
        vertices. Every region becomes one instance annotation of category 0.

        Args:
            img_dir: directory containing the images and ``via_region_data.json``.

        Returns:
            list[dict]: one record per image with keys ``file_name``,
            ``image_id``, ``height``, ``width`` and ``annotations``.

        Raises:
            FileNotFoundError: if an image referenced by the JSON cannot be read.
            AssertionError: if a region carries non-empty ``region_attributes``.
        """
        json_file = os.path.join(img_dir, "via_region_data.json")
        with open(json_file, encoding="utf-8") as f:
            imgs_anns = json.load(f)

        dataset_dicts = []
        for idx, v in enumerate(imgs_anns.values()):
            filename = os.path.join(img_dir, v["filename"])
            img = cv2.imread(filename)
            # cv2.imread returns None instead of raising on a missing/corrupt file.
            if img is None:
                raise FileNotFoundError(f"could not read image {filename!r}")
            height, width = img.shape[:2]

            objs = []
            for anno in v["regions"].values():
                # Only single-class data is supported: regions must carry no attributes.
                assert not anno["region_attributes"]
                shape = anno["shape_attributes"]
                px = shape["all_points_x"]
                py = shape["all_points_y"]
                # Offset every vertex by +0.5 (presumably a pixel-centre convention)
                # and flatten to [x0, y0, x1, y1, ...] as the segmentation format.
                poly = [coord for x, y in zip(px, py) for coord in (x + 0.5, y + 0.5)]
                objs.append({
                    "bbox": [np.min(px), np.min(py), np.max(px), np.max(py)],
                    "bbox_mode": BoxMode.XYXY_ABS,
                    "segmentation": [poly],
                    "category_id": 0,
                })

            dataset_dicts.append({
                "file_name": filename,
                "image_id": idx,
                "height": height,
                "width": width,
                "annotations": objs,
            })
        return dataset_dicts
    
    
    # Register the train/val splits. DatasetCatalog stores the loader callable;
    # `d=d` binds the current split name (avoids the late-binding closure pitfall).
    for d in ["train", "val"]:
        DatasetCatalog.register("balloon_" + d, lambda d=d: get_balloon_dicts("dataset/balloon/" + d))
        MetadataCatalog.get("balloon_" + d).set(thing_classes=["balloon"])
    balloon_metadata = MetadataCatalog.get("balloon_train")

    # Sanity-check the annotations by drawing up to three random training samples.
    dataset_dicts = get_balloon_dicts("dataset/balloon/train")
    for i, d in enumerate(random.sample(dataset_dicts, min(3, len(dataset_dicts)))):
        img = cv2.imread(d["file_name"])
        # Visualizer takes RGB input; OpenCV loads BGR.
        visualizer = Visualizer(img[:, :, ::-1], metadata=balloon_metadata, scale=0.5)
        out = visualizer.draw_dataset_dict(d)
        # cv2_imshow(out.get_image()[:, :, ::-1])
        plt.figure(figsize=(14, 10))
        plt.imshow(out.get_image())  # get_image() is already RGB
        # One file per sample so earlier previews are not overwritten.
        plt.savefig(f'./read_{i}.jpg')
        plt.close()  # free each figure instead of accumulating them
    
    
    ################### Train ####################
    from detectron2.engine import DefaultTrainer

    # Start from the COCO Mask R-CNN R50-FPN 3x config and its checkpoint,
    # then fine-tune on the single-class balloon dataset.
    cfg = get_cfg()
    cfg.merge_from_file('../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml')
    cfg.MODEL.WEIGHTS = 'model_final_f10217.pkl'

    # Data: train on the registered balloon split, with no evaluation split.
    cfg.DATASETS.TRAIN = ("balloon_train",)
    cfg.DATASETS.TEST = ()
    cfg.DATALOADER.NUM_WORKERS = 2

    # Solver: short schedule; an empty STEPS list means no learning-rate decay steps.
    cfg.SOLVER.IMS_PER_BATCH = 2
    cfg.SOLVER.BASE_LR = 0.00025
    cfg.SOLVER.MAX_ITER = 300
    cfg.SOLVER.STEPS = []

    # ROI heads: 128 sampled RoIs per image (default: 512); one class (balloon).
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1

    # Ensure the output directory exists, then train starting from MODEL.WEIGHTS
    # (resume=False: do not pick up a previous run's last checkpoint).
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    trainer = DefaultTrainer(cfg)
    trainer.resume_or_load(resume=False)
    trainer.train()
    

      

    3. print_intermediate_features.py

    import torch, torchvision
    import detectron2
    from detectron2.utils.logger import setup_logger
    # setup_logger must be *called* so detectron2's log messages are emitted;
    # a bare reference to the function is a no-op.
    setup_logger()

    import numpy as np
    import os, json, cv2, random
    import matplotlib.pyplot as plt

    from detectron2 import model_zoo
    from detectron2.engine import DefaultPredictor
    from detectron2.config import get_cfg
    from detectron2.utils.visualizer import Visualizer
    from detectron2.data import MetadataCatalog, DatasetCatalog
    from detectron2.modeling import build_model
    from detectron2.modeling import build_backbone
    from detectron2.checkpoint import DetectionCheckpointer
    from detectron2.structures import ImageList
    
    
    ########### Read data ############
    # im = cv2.imread('./input.jpg')
    # cv2_imshow(im)

    ########## Configuration #############
    cfg = get_cfg()
    # Alternative, resolving the config through detectron2.model_zoo:
    # cfg.merge_from_file(model_zoo.get_config_file('COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml'))
    cfg.merge_from_file('../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml')
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
    # Alternative, resolving the weights through detectron2.model_zoo:
    # cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url('COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml')
    cfg.MODEL.WEIGHTS = 'model_final_a54504.pkl'   # COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml

    ############# Prepare the input: OpenCV HWC uint8 array -> CHW float32 tensor ##########
    image = cv2.imread('./input.jpg')
    # cv2.imread returns None (no exception) when the file cannot be read.
    if image is None:
        raise FileNotFoundError("could not read image './input.jpg'")
    height, width = image.shape[:2]
    image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
    inputs = [{"image": image, "height": height, "width": width}]

    ############ Run only part of the network ###############
    model = build_model(cfg)   # returns a torch.nn.Module with random parameters
    # Load the checkpoint named in the config, so the path is defined in one place.
    DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)

    model.eval()
    with torch.no_grad():
        # preprocess_image batches the inputs; its .tensor is fed to the backbone alone.
        images = model.preprocess_image(inputs)
        features = model.backbone(images.tensor)
        # outputs = model(image)
        # features = model.backbone(image)

    # `features` is a dict of named feature maps; print the actual keys.
    print(features.keys())
    # Explicit utf-8: the message below is non-ASCII and the platform default
    # encoding may be unable to represent it.
    with open('./print_features.txt_segmentation', 'w', encoding='utf-8') as f:
        print("features是一个字典,key包括['p2', 'p3', 'p4', 'p5', 'p6']", features['p2'], file=f)

    # To inspect the model's top-level submodules:
    # for name, child in model.named_children():
    #     print(name, type(child))
    #
    # Recorded types of the children for this config:
    #   <class 'detectron2.modeling.backbone.fpn.FPN'>
    #   <class 'detectron2.modeling.proposal_generator.rpn.RPN'>
    #   <class 'detectron2.modeling.roi_heads.roi_heads.StandardROIHeads'>
    # (each paired with a `str` name)
    

      




    如果这篇文章帮助到了你,你可以请作者喝一杯咖啡

  • 相关阅读:
    axios
    vue打包之后生成一个配置文件修改请求接口
    微信小程序小结(2) ------ 自定义组件
    eros --- Windows Android真机调试
    weex前端式写法解决方案---eros
    微信小程序小结(1) ------ 前后端交互及wx.request的简易封装
    configparser模块--配置文件
    怎样尊重一个程序员
    poj1326(bfs)
    安装篇——压缩包安装MySql数据库
  • 原文地址:https://www.cnblogs.com/sddai/p/14430655.html
Copyright © 2011-2022 走看看