  • MaskRCNN for landmarks: using the TensorFlow version for image cutouts

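                MaskRCNN is used here to detect landmarks. As a finer-grained form of object detection, it yields a more precise landmark position and the landmark's geometric center, which are used to build a more accurate topological map and reduce mapping error.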
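
The geometric center of a landmark can be estimated from its instance mask as the mean of the mask's pixel coordinates. The following is a minimal sketch, assuming a 2-D boolean mask of shape [height, width]; `mask_centroid` is a hypothetical helper and is not part of the original code.

```python
import numpy as np

def mask_centroid(mask):
    """Return the (row, col) centroid of a 2-D mask, or None if the mask is empty."""
    rows, cols = np.nonzero(mask)
    if rows.size == 0:
        return None
    return float(rows.mean()), float(cols.mean())

# Usage: a 3x3 square centered at row 2, column 3
demo = np.zeros((5, 7), dtype=bool)
demo[1:4, 2:5] = True
print(mask_centroid(demo))  # (2.0, 3.0)
```

The result is in (row, col) order, matching the (y, x) order of the bounding boxes used later in the post.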

                The cutout tool is finished: it crops each detection box out of the image and uses the value 0 for the background.
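
The idea can be illustrated without the model. This sketch assumes an RGB image array, a boolean instance mask of the same height and width, and a box in (y1, x1, y2, x2) order; `cut_out` is a hypothetical name used only for illustration.

```python
import numpy as np

def cut_out(image, mask, box):
    """Zero every pixel outside `mask`, then crop to box = (y1, x1, y2, x2)."""
    y1, x1, y2, x2 = box
    result = image.copy()
    result[~mask.astype(bool)] = 0  # background becomes 0 on all channels
    return result[y1:y2, x1:x2]

# Usage: a uniform gray image with a 2x2 mask that has one hole
image = np.full((4, 4, 3), 200, dtype=np.uint8)
mask = np.zeros((4, 4), dtype=bool)
mask[1:3, 1:3] = True
mask[1, 1] = False  # a background pixel inside the box
roi = cut_out(image, mask, (1, 1, 3, 3))
print(roi.shape)  # (2, 2, 3)
print(roi[0, 0])  # [0 0 0]
print(roi[1, 1])  # [200 200 200]
```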


    Python code:

    import os

    import cv2
    import numpy as np
    import skimage.io

    # Assumes the matterport Mask_RCNN package layout; older checkouts
    # use `import model as modellib` instead.
    import mrcnn.model as modellib

    # InferenceConfig, init_classname, traverseFolder and apply_maskX are
    # defined elsewhere in the project and are not shown in this post.

    def mainex():
        # initDir()
        # Root directory of the project
        ROOT_DIR = os.getcwd()

        # Directory to save logs and trained model
        MODEL_DIR = os.path.join(ROOT_DIR, "logs")

        # Path to the trained weights file. Download this file and place it
        # in the project (see the README file for details).
        # COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")
        COCO_MODEL_PATH = "D:/Works/PyProj/MaskRCNN-tensor/data/model/mask_rcnn_coco.h5"

        config = InferenceConfig()
        config.display()

        # Create the model object in inference mode.
        model = modellib.MaskRCNN(mode="inference", model_dir=MODEL_DIR, config=config)

        # Load weights trained on MS-COCO
        model.load_weights(COCO_MODEL_PATH, by_name=True)

        # Class names, indexed by class ID
        class_names = init_classname()

        # Directory of images to run detection on
        # IMAGE_DIR = os.path.join(ROOT_DIR, "images")
        IMAGE_DIR = "D:/DataSet/PicStyleTest/Medsea3/deskfilter/"
        proDir(model, class_names, IMAGE_DIR)

    Processing a directory:

    def proDir(model, class_names, IMAGE_DIR):
        # Process every .jpg image found under IMAGE_DIR
        print(IMAGE_DIR)

        extension = ".jpg"
        filelist = traverseFolder(IMAGE_DIR, extension)

        for file in filelist:
            print("Processing:", file)
            image = skimage.io.imread(file)

            # Run detection
            results = model.detect([image], verbose=1)

            # Cut out every detected instance and save it next to the source image
            getAllLabelMask(file, image, results[0], class_names)
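
The post does not show `traverseFolder`. The following is a hedged stand-in, assuming it returns the paths of all files under a directory whose names end with the given extension; whether the author's version recurses into subfolders or ignores case is not known, so both behaviors here are assumptions.

```python
import os

def traverseFolder(root_dir, extension):
    """Return sorted paths of files under root_dir whose names end with extension."""
    matches = []
    # os.walk visits root_dir and all of its subdirectories (assumed behavior)
    for dirpath, _dirnames, filenames in os.walk(root_dir):
        for name in filenames:
            # Case-insensitive match, so "x.JPG" also counts as ".jpg" (assumed)
            if name.lower().endswith(extension.lower()):
                matches.append(os.path.join(dirpath, name))
    return sorted(matches)
```

Sorting the result makes the processing order reproducible across runs, which helps when comparing the numbered output files.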

    def getAllLabelMask(fileName, image, maskResult, class_names):
        """
        Save one cut-out image per detected instance.

        fileName: path of the source image; outputs are written next to it.
        image: the source image as an RGB array [height, width, 3].
        maskResult: one entry of the model.detect() output, with keys
            rois: [num_instances, (y1, x1, y2, x2)] in image coordinates.
            masks: [height, width, num_instances]
            class_ids: [num_instances]
        class_names: list of class names of the dataset
        """
        boxes = maskResult['rois']
        masks = maskResult['masks']
        class_ids = maskResult['class_ids']

        # Number of instances
        N = boxes.shape[0]
        if N < 1:
            return
        # Bail out if the result arrays disagree on the instance count
        if not (boxes.shape[0] == masks.shape[-1] == class_ids.shape[0]):
            return

        height, width = image.shape[:2]
        for i in range(N):
            # Skip empty bounding boxes
            if not np.any(boxes[i]):
                continue
            y1, x1, y2, x2 = boxes[i]

            # Label
            label = class_names[class_ids[i]]

            # Paint the instance mask onto a black canvas
            mask = masks[:, :, i]
            masked_image = np.zeros((height, width, 3), dtype=np.uint8)
            masked_image = apply_maskX(masked_image, mask)

            # Zero out every pixel outside the mask (background = 0)
            frontImage = image.copy()
            frontImage[masked_image[:, :, 0] < 254] = 0

            # Crop to the bounding box; skimage loads RGB, OpenCV writes BGR
            roiImg = frontImage[y1:y2, x1:x2]
            roiImg = cv2.cvtColor(roiImg, cv2.COLOR_RGB2BGR)

            # Output name: <source name>.<index>.<label>.Mask.png
            fileMask = os.path.splitext(fileName)[0]
            fileMask = fileMask + "." + str(i) + "." + label + ".Mask.png"

            cv2.imwrite(fileMask, roiImg)
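
`apply_maskX` is also not shown in the post. Since the code above treats canvas pixels with a value below 254 as background, one plausible reading is that it paints masked pixels white. The sketch below is written under that assumption and is not the author's implementation.

```python
import numpy as np

def apply_maskX(image, mask, value=255):
    """Hypothetical stand-in: set every masked pixel of `image` to `value` on all channels."""
    out = image.copy()
    out[mask.astype(bool)] = value
    return out

# Usage: paint a single masked pixel onto a black canvas
canvas = np.zeros((2, 3, 3), dtype=np.uint8)
mask = np.array([[0, 1, 0],
                 [0, 0, 0]], dtype=np.uint8)
painted = apply_maskX(canvas, mask)
print(painted[0, 1])  # [255 255 255]
print(painted[1, 2])  # [0 0 0]
```

With this behavior the intermediate canvas is not strictly needed, because the instance mask can index the image directly; it is kept here only to match the structure of the original code.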


    Results:








  • Original article: https://www.cnblogs.com/wishchin/p/9199882.html