  • Calling a trained Detectron model

    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    from __future__ import unicode_literals
    
    from collections import defaultdict
    import cv2  # NOQA (Must import before importing caffe2 due to bug in cv2)
    from caffe2.python import workspace
    from detectron.core.config import assert_and_infer_cfg
    # from detectron.core.config import cfg
    from detectron.core.config import merge_cfg_from_file
    from detectron.utils.io import cache_url
    from detectron.utils.timer import Timer
    import detectron.core.test_engine as infer_engine
    import detectron.datasets.dummy_datasets as dummy_datasets
    import detectron.utils.c2 as c2_utils
    import detectron.utils.vis as vis_utils
    import numpy as np
    import pycocotools.mask as mask_util
    c2_utils.import_detectron_ops()
    # OpenCL may be enabled by default in OpenCV3; disable it because it's not
    # thread safe and causes unwanted GPU memory allocations.
    # cv2.ocl.setUseOpenCL(False)
    #coco
    # weights = "/home/gaomh/Desktop/test/cocomodel/model_final.pkl"
    # config = "/home/gaomh/Desktop/test/cocomodel/e2e_mask_rcnn_R-101-FPN_1x.yaml"
    #hat
    weights = "/home/gaomh/Desktop/test/models/kp-person/model_final.pkl"
    config = "/home/gaomh/Desktop/test/models/kp-person/e2e_keypoint_rcnn_X-101-32x8d-FPN_1x.yaml"
    #foot
    # weights = "/home/gaomh/Desktop/test/trainMOdel/train/voc_2007_train/retinanet/model_final.pkl"
    # config = "/home/gaomh/Desktop/test/trainMOdel/train/voc_2007_train/retinanet_R-50-FPN_1x.0.yaml"
    gpuid = 0
    workspace.GlobalInit(['caffe2', '--caffe2_log_level=0'])
    merge_cfg_from_file(config)
    assert_and_infer_cfg(cache_urls=False)
    
    model = infer_engine.initialize_model_from_cfg(weights, gpuid)
    dataset = dummy_datasets.get_foot_dataset()
    
    
    def convert_from_cls_format(cls_boxes, cls_segms):
        """Convert the per-class boxes/segms format produced by the testing
        code into flat lists of boxes, segms and class indices.
        """
        box_list = [b for b in cls_boxes if len(b) > 0]
        if len(box_list) > 0:
            boxes = np.concatenate(box_list)
        else:
            boxes = None
        if cls_segms is not None:
            segms = [s for slist in cls_segms for s in slist]
        else:
            segms = None
        classes = []
        for j in range(len(cls_boxes)):
            classes += [j] * len(cls_boxes[j])
        return boxes, segms, classes
    
    
    def vis_one_image(boxes, cls_segms, thresh=0.9):
        """Collect boxes and masks whose score is above `thresh`."""
        result_box = []
        result_mask = []
        segms, classes = None, None
        if isinstance(boxes, list):
            boxes, segms, classes = convert_from_cls_format(boxes, cls_segms)

        if boxes is None or boxes.shape[0] == 0 or max(boxes[:, 4]) < thresh:
            return result_box, result_mask
        # mask_util.decode returns an H x W x N array of binary masks
        masks = mask_util.decode(segms) if segms is not None else None

        # Process detections from largest to smallest area to reduce occlusion
        areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        sorted_inds = np.argsort(-areas)

        for i in sorted_inds:
            bbox = boxes[i, :4]
            score = boxes[i, -1]
            if score < thresh:
                continue
            result_box.append([
                dataset.classes[classes[i]], score,
                int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
            ])
            if masks is not None and masks.shape[2] > i:
                result_mask.append(masks[:, :, i])
            else:
                result_mask.append([])

        return result_box, result_mask
    
    # cap = cv2.VideoCapture("rtsp://192.168.123.231")
    cap = cv2.VideoCapture("/home/gaomh/per.mp4")
    # cv2.namedWindow("img", cv2.WINDOW_NORMAL)
    while cap.isOpened():
        res, frame = cap.read()
        if not res:
            # End of stream or failed read
            break
        timers = defaultdict(Timer)
        with c2_utils.NamedCudaScope(0):
            cls_boxes, cls_segms, cls_keyps = infer_engine.im_detect_all(
                model, frame, None, timers=timers
            )
        # Alternatively, let Detectron's own visualizer draw everything:
        # img = vis_utils.vis_one_image_opencv(im=frame, boxes=cls_boxes, segms=cls_segms,
        #                                      keypoints=cls_keyps, thresh=0.7, kp_thresh=2,
        #                                      show_box=False, dataset=dataset, show_class=True)
    
        result_box, result_mask = vis_one_image(cls_boxes, cls_segms)
        print(result_box)
        for box in result_box:
            tit = box[0]      # class name
            thr = box[1]      # detection score
            left = box[2]
            top = box[3]
            right = box[4]
            bottom = box[5]
            # if tit == "person":
            cv2.rectangle(frame, (left, top), (right, bottom), (255, 0, 0), 1)
            cv2.putText(frame, tit, (left - 10, top - 10), cv2.FONT_HERSHEY_COMPLEX, 0.4, (0, 0, 255))
        cv2.imshow("img", frame)
        key = cv2.waitKey(1)
        if key == ord("q"):
            break
    
    cv2.destroyAllWindows()
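
    The masks collected in result_mask are never drawn in the loop above. Below is a minimal sketch of how they could be blended into the frame, assuming the result_mask list produced by vis_one_image; the overlay_masks helper and its color/alpha defaults are hypothetical, not part of Detectron:

    import numpy as np
    import cv2

    def overlay_masks(frame, result_mask, color=(0, 255, 0), alpha=0.4):
        """Blend each binary mask returned by vis_one_image into the frame."""
        out = frame.copy()
        for mask in result_mask:
            if len(mask) == 0:
                continue  # no mask was available for this detection
            colored = np.zeros_like(out)
            colored[mask > 0] = color  # paint the mask region in a solid color
            out = cv2.addWeighted(out, 1.0, colored, alpha, 0)
        return out

    # Usage inside the video loop, right after vis_one_image:
    # frame = overlay_masks(frame, result_mask)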

    Modify dummy_datasets.py and add the corresponding classes:

    def get_foot_dataset():
        """A dummy COCO dataset that includes only the 'classes' field."""
        ds = AttrDict()
        classes = [
            '__background__', 'person', 'foot', 'car'
        ]
        ds.classes = {i: name for i, name in enumerate(classes)}
        return ds
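
    The index order here must match the label indices the model was trained with (background at index 0, foreground classes in training order), since vis_one_image looks names up via dataset.classes[classes[i]]. A quick sanity check, assuming the function above:

    ds = get_foot_dataset()
    assert ds.classes[0] == '__background__'
    assert ds.classes[1] == 'person'
    print(len(ds.classes))  # 4 entries: background + 3 foreground classes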

    Result screenshot:

  • Original article: https://www.cnblogs.com/answerThe/p/12121176.html