zoukankan      html  css  js  c++  java
  • TensorFlow Mask R-CNN(三)——实时检测视频

    参考:https://www.youtube.com/watch?v=lLM8oAsi32g

        import cv2
        import numpy as np
         
         
        def random_colors(N):
            np.random.seed(1)
            colors=[tuple(255*np.random.rand(3)) for _ in range(N)]
            return colors
         
        def apply_mask(image, mask, color, alpha=0.5):
            """Apply the given mask to the image.
            """
            for n, c in enumerate(color):
                image[:, :, n] = np.where(
                    mask == 1,
                    image[:, :, n] *(1 - alpha) + alpha * c,
                    image[:, :, n]
                )
            return image
         
        def display_instances(image,boxes,masks,ids,names,scores):
            n_instances=boxes.shape[0]
            if not n_instances:
                print('No instances to display')
            else:
                assert boxes.shape[0] == masks.shape[-1] == ids.shape[0]
            
            colors=random_colors(n_instances)
            height, width = image.shape[:2]
            
            for i,color in enumerate(colors):
                if not np.any(boxes[i]):
                    continue
                
                y1,x1,y2,x2=boxes[i]
                mask=masks[:,:,i]
                image=apply_mask(image,mask,color)
                image=cv2.rectangle(image,(x1,y1),(x2,y2),color,2)
                
                label=names[ids[i]]
                score=scores[i] if scores is not None else None
                
                caption='{}{:.2f}'.format(label,score) if score else label
                image=cv2.putText(
                    image,caption,(x1,y1),cv2.FONT_HERSHEY_COMPLEX,0.7,color,2
                )
                
            return image
         
        if __name__=='__main__':
            import os
            import sys
            import random
            import math
            import skimage.io
            import time
            import utils
            #import model as modellib
            
            
            ROOT_DIR = os.path.abspath("../")
            sys.path.append(ROOT_DIR)
            from mrcnn import utils
            import mrcnn.model as modellib
         
         
            sys.path.append(os.path.join(ROOT_DIR, "samples/coco/"))  # To find local version
            import coco
            
         
            MODEL_DIR = os.path.join(ROOT_DIR, "logs")
            COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")
            if not os.path.exists(COCO_MODEL_PATH):
                print('cannot find coco_model')
                
            class InferenceConfig(coco.CocoConfig):
                GPU_COUNT = 1
                IMAGES_PER_GPU = 1
         
            config = InferenceConfig()
            config.display()
            
            model = modellib.MaskRCNN(
                mode="inference", model_dir=MODEL_DIR, config=config
            )
         
            # Load weights trained on MS-COCO
            model.load_weights(COCO_MODEL_PATH, by_name=True)
            class_names = ['BG', 'person', 'bicycle', 'car', 'motorcycle', 'airplane',
                       'bus', 'train', 'truck', 'boat', 'traffic light',
                       'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird',
                       'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear',
                       'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
                       'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
                       'kite', 'baseball bat', 'baseball glove', 'skateboard',
                       'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
                       'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
                       'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
                       'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
                       'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
                       'keyboard', 'cell phone', 'microwave', 'oven', 'toaster',
                       'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
                       'teddy bear', 'hair drier', 'toothbrush']
            
            
            capture=cv2.VideoCapture(0)
            capture.set(cv2.CAP_PROP_FRAME_WIDTH,1920)
            capture.set(cv2.CAP_PROP_FRAME_HEIGHT,1080)
            
            while True:
                ret,frame=capture.read()
                results=model.detect([frame],verbose=0)
                r=results[0]
                
                
                frame=display_instances(
                      frame,r['rois'], r['masks'], r['class_ids'],
                                    class_names, r['scores']
                )
                
                cv2.imshow('frame',frame)
                if cv2.waitKey(1)&0xFF==ord('q'):
                    break
               
            capture.release()
            cv2.destroyAllWindows()


  • 相关阅读:
    java接口对入参的判断校验
    sqlyog使用技巧
    mysql 数据库的表中复制一条数据并添加到该表中
    union all ,union 注意事项,查询结果集中的字段名称顺序必须一致
    IDEA连接MySQL数据库,其实很简单
    git版本回退、git远程分支管理、git本地分支管理、git生产代码bug修复
    Vue上拉加载下拉刷新---vue-easyrefresh
    Flutter上拉加载下拉刷新---flutter_easyrefresh
    vue-cli webpack多Html页面的配置(附框架vue-webpack-multipage实例)
    Qt使用镜像源快速安装与更新
  • 原文地址:https://www.cnblogs.com/shuimuqingyang/p/10102849.html
Copyright © 2011-2022 走看看