  • maskrcnn_benchmark code analysis (1)

    Debugging with ipdb:

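    try:
        import ipdb
    except ImportError:  # fall back to the standard-library debugger
        import pdb as ipdb
    
    ipdb.set_trace()  # execution pauses here and opens an interactive prompt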

    Testing inference:

    # coding=utf-8
    
    import matplotlib.pyplot as plt
    import matplotlib.pylab as pylab
    
    import requests
    from io import BytesIO
    from PIL import Image
    import numpy as np
    
    # this makes our figures bigger
    pylab.rcParams['figure.figsize'] = 20, 12
    
    from maskrcnn_benchmark.config import cfg
    from predictor import COCODemo
    
    
    config_file = "../configs/caffe2/e2e_mask_rcnn_R_50_FPN_1x_caffe2.yaml"
    #config_file = "../configs/e2e_mask_rcnn_R_50_FPN_1x.yaml"
    
    # update the config options with the config file
    cfg.merge_from_file(config_file)
    # manual override some options
    cfg.merge_from_list(["MODEL.DEVICE", "cuda"]) # only "cuda" and "cpu" are valid device types
    coco_demo = COCODemo(
        cfg,
        min_image_size=800,
        confidence_threshold=0.7,
    )
    
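    def load(url):
        """
        Given the URL of an image, downloads the image and
        returns it as a NumPy array in BGR channel order
        """
        response = requests.get(url)
        response.raise_for_status()  # fail early if the download failed
        pil_image = Image.open(BytesIO(response.content)).convert("RGB")
        # convert RGB to BGR, the channel order run_on_opencv_image expects
        image = np.array(pil_image)[:, :, [2, 1, 0]]
        return image
    
    def imshow(img):
        # matplotlib expects RGB, so flip the BGR channels back
        plt.imshow(img[:, :, [2, 1, 0]])
        plt.axis("off")
        plt.show()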
    
    # from http://cocodataset.org/#explore?id=345434
    image = load("http://farm3.staticflickr.com/2469/3915380994_2e611b1779_z.jpg")
    # to use a local file instead:
    # image = Image.open("474797538.jpg").convert("RGB")
    # image = np.array(image)[:, :, [2, 1, 0]]
    # imshow(image)
    
    # compute predictions
    predictions = coco_demo.run_on_opencv_image(image)
    imshow(predictions)

    Variable values inside the core function def compute_prediction(self, original_image): in predictor.py:

    -> The input original_image has shape [480, 640, 3] (H, W, C) and holds integer (uint8) data;

    -> After the transforms, image has shape [3, 800, 1066] with dtype torch.float32;
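    The change from 480x640 to 800x1066 comes from resizing the shorter side to min_image_size=800 while keeping the aspect ratio (640 * 800 / 480 = 1066.67, truncated to 1066). Below is a small sketch of that size computation. The max_size=1333 cap is an assumption borrowed from common Mask R-CNN test settings and does not trigger for this image.

```python
def shorter_side_resize(h, w, min_size=800, max_size=1333):
    """Return (new_h, new_w) so the shorter side becomes min_size,
    shrinking the target if the longer side would exceed max_size.
    max_size=1333 is an assumed default, not taken from the post."""
    size = min_size
    short_side, long_side = min(h, w), max(h, w)
    if long_side / short_side * size > max_size:
        size = int(round(max_size * short_side / long_side))
    if h < w:
        return size, int(size * w / h)  # truncate, as integer sizes require
    return int(size * h / w), size


print(shorter_side_resize(480, 640))  # (800, 1066)
```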

    Next comes the core call predictions = self.model(image_list), which steps into def forward(self, images, targets=None): in generalized_rcnn.py;

    features = self.backbone(images.tensors) runs a backbone network (e.g. ResNet-50-FPN) that extracts a feature map for each stage. These feature maps are then used by the RPN and by RoI pooling;

    -> features is a tuple of 5 feature-map tensors, one per FPN level:

    ipdb> p features.size()
    *** AttributeError: 'tuple' object has no attribute 'size'
    ipdb> p features.shape()
    *** AttributeError: 'tuple' object has no attribute 'shape'
    ipdb> p features[0].shape()
    *** TypeError: 'torch.Size' object is not callable
    ipdb> p features[0].size()
    torch.Size([1, 256, 200, 272])
    ipdb> p features[1].size()
    torch.Size([1, 256, 100, 136])
    ipdb> p features[2].size()
    torch.Size([1, 256, 50, 68])
    ipdb> p features[3].size()
    torch.Size([1, 256, 25, 34])
    ipdb> p features[4].size()
    torch.Size([1, 256, 13, 17])
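    The widths look odd: 272 rather than 1066 / 4 = 266.5. This suggests the 800x1066 image is zero-padded before it enters the backbone. The sketch below reproduces all five shapes under three assumptions: padding to a multiple of 32 (the SIZE_DIVISIBILITY that I believe the FPN configs set and that predictor.py passes to to_image_list), strides 4/8/16/32 for the first four levels, and an extra top level produced by a stride-2 max pool on the last one. The helper names are hypothetical.

```python
import math


def pad_to_multiple(h, w, divisibility=32):
    """Padded (H, W) when an image is batched with size divisibility 32 (assumed)."""
    return (math.ceil(h / divisibility) * divisibility,
            math.ceil(w / divisibility) * divisibility)


def fpn_feature_shapes(h, w, channels=256, strides=(4, 8, 16, 32)):
    """Predict (N, C, H, W) for the five FPN outputs of a single image."""
    ph, pw = pad_to_multiple(h, w)
    # the padded size is divisible by 32, so every stride divides it exactly
    shapes = [(1, channels, ph // s, pw // s) for s in strides]
    # extra level: max_pool2d(kernel_size=1, stride=2, padding=0) on the last map,
    # whose output size is floor((n - 1) / 2) + 1
    _, _, h_last, w_last = shapes[-1]
    shapes.append((1, channels, (h_last - 1) // 2 + 1, (w_last - 1) // 2 + 1))
    return shapes


for shape in fpn_feature_shapes(800, 1066):
    print(shape)
```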

    -> The RPN produces the candidate boxes: proposals, proposal_losses = self.rpn(images, features, targets)

    ipdb> targets
    ipdb> p targets
    None
    ipdb> p images
    <maskrcnn_benchmark.structures.image_list.ImageList object at 0x7f5128049f28>
    ipdb> p proposal_losses
    {}
    ipdb> p proposals
    [BoxList(num_boxes=1000, image_width=1066, image_height=800, mode=xyxy)]
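    At test time the RPN does four things: it scores its anchors, decodes them into boxes, applies non-maximum suppression (NMS), and keeps the top-scoring survivors. The 1000 here presumably comes from a test-time post-NMS top-N setting in the config (MODEL.RPN.FPN_POST_NMS_TOP_N_TEST, if I read the config correctly). Below is a plain NumPy sketch of greedy NMS followed by top-N selection. The 0.7 IoU threshold is an assumption. The library's own NMS is a compiled op and differs in details such as the +1 box-width convention.

```python
import numpy as np


def box_iou(box, boxes):
    """IoU between one xyxy box (4,) and many xyxy boxes (N, 4)."""
    x1 = np.maximum(box[0], boxes[:, 0])
    y1 = np.maximum(box[1], boxes[:, 1])
    x2 = np.minimum(box[2], boxes[:, 2])
    y2 = np.minimum(box[3], boxes[:, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area + areas - inter)


def nms_topk(boxes, scores, iou_thresh=0.7, top_n=1000):
    """Greedy NMS; returns indices of at most top_n kept boxes, best first."""
    order = np.argsort(-scores)  # highest score first
    keep = []
    while order.size > 0 and len(keep) < top_n:
        best = order[0]
        keep.append(best)
        rest = order[1:]
        # drop every remaining box that overlaps the kept one too much
        order = rest[box_iou(boxes[best], boxes[rest]) <= iou_thresh]
    return np.array(keep, dtype=np.int64)


boxes = np.array([[0, 0, 10, 10], [1, 1, 10, 10], [20, 20, 30, 30]], dtype=float)
scores = np.array([0.9, 0.8, 0.7])
print(nms_topk(boxes, scores))  # box 1 overlaps box 0 with IoU 0.81 and is dropped
```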

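    -> Next comes the Fast R-CNN stage: x, result, detector_losses = self.roi_heads(features, proposals, targets). This part is implemented in roi_heads.py and consists of two branches: a detection branch and a segmentation (mask) branch;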

    -> In forward() of roi_heads.py, x, detections, loss_box = self.box(features, proposals, targets) produces the detection results. The box branch's forward() is:

        def forward(self, features, proposals, targets=None):
            """
            Arguments:
                features (list[Tensor]): feature-maps from possibly several levels
                proposals (list[BoxList]): proposal boxes
                targets (list[BoxList], optional): the ground-truth targets.
    
            Returns:
                x (Tensor): the result of the feature extractor
                proposals (list[BoxList]): during training, the subsampled proposals
                    are returned. During testing, the predicted boxlists are returned
                losses (dict[Tensor]): During training, returns the losses for the
                    head. During testing, returns an empty dict.
            """
    
            if self.training:
                # Faster R-CNN subsamples during training the proposals with a fixed
                # positive / negative ratio
                with torch.no_grad():
                    proposals = self.loss_evaluator.subsample(proposals, targets)
    
            # extract features that will be fed to the final classifier. The
            # feature_extractor generally corresponds to the pooler + heads
            x = self.feature_extractor(features, proposals)
            # final classifier that converts the features into predictions
            class_logits, box_regression = self.predictor(x)
    
            if not self.training:
                result = self.post_processor((class_logits, box_regression), proposals)
                return x, result, {}
    
            loss_classifier, loss_box_reg = self.loss_evaluator(
                [class_logits], [box_regression]
            )
            return (
                x,
                proposals,
                dict(loss_classifier=loss_classifier, loss_box_reg=loss_box_reg),
            )
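    self.predictor(x) returns class_logits and box_regression. The post_processor then turns box_regression, which holds deltas relative to each proposal, into actual boxes. Below is a minimal NumPy sketch of the standard Faster R-CNN decoding. It assumes weights (10, 10, 5, 5), which I believe is maskrcnn_benchmark's default MODEL.ROI_HEADS.BBOX_REG_WEIGHTS, and it ignores the library's +1 pixel convention. decode_boxes is a hypothetical name.

```python
import math

import numpy as np


def decode_boxes(deltas, boxes, weights=(10.0, 10.0, 5.0, 5.0),
                 clip=math.log(1000.0 / 16)):
    """Apply (dx, dy, dw, dh) deltas to xyxy reference boxes; both are (N, 4)."""
    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    ctr_x = boxes[:, 0] + 0.5 * widths
    ctr_y = boxes[:, 1] + 0.5 * heights

    wx, wy, ww, wh = weights
    dx = deltas[:, 0] / wx
    dy = deltas[:, 1] / wy
    # clamp the log-scale deltas so exp() cannot blow up
    dw = np.minimum(deltas[:, 2] / ww, clip)
    dh = np.minimum(deltas[:, 3] / wh, clip)

    pred_ctr_x = dx * widths + ctr_x   # shift the center by a fraction of the size
    pred_ctr_y = dy * heights + ctr_y
    pred_w = np.exp(dw) * widths       # scale the size multiplicatively
    pred_h = np.exp(dh) * heights

    return np.stack([pred_ctr_x - 0.5 * pred_w,
                     pred_ctr_y - 0.5 * pred_h,
                     pred_ctr_x + 0.5 * pred_w,
                     pred_ctr_y + 0.5 * pred_h], axis=1)


proposal = np.array([[10.0, 20.0, 50.0, 60.0]])
print(decode_boxes(np.array([[10.0, 0.0, 0.0, 0.0]]), proposal))  # shifted right by one width
```

    In the real head, box_regression holds 4 deltas per class. You would therefore reshape it to (N, num_classes, 4) and decode the slice for each class.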

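    -> x holds the RoI features after pooling and feature extraction, which are used for classification and box regression. After post-processing, only the useful boxes are returned;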

    -> A 1024-dim feature is extracted for each of the 1000 selected proposals. In the end, 88 valid boxes remain:

    ipdb> x.shape
    torch.Size([1000, 1024])
    ipdb> detections.shape
    *** AttributeError: 'list' object has no attribute 'shape'
    ipdb> detections.size()
    *** AttributeError: 'list' object has no attribute 'size'
    ipdb> len(detections)
    1
    ipdb> detections
    [BoxList(num_boxes=88, image_width=1066, image_height=800, mode=xyxy)]
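    The drop from 1000 proposals to 88 boxes happens in the post-processor. It applies a softmax over the class logits, drops the background column, and keeps (proposal, class) pairs above a low score threshold. It then runs per-class NMS and caps the count per image. The sketch below covers only the softmax, threshold and cap steps and skips NMS for brevity. score_thresh=0.05 and max_dets=100 are what I believe the library defaults to, so treat them as assumptions. filter_detections is a hypothetical helper.

```python
import numpy as np


def filter_detections(class_logits, score_thresh=0.05, max_dets=100):
    """class_logits: (N, C) with column 0 = background.
    Returns (proposal indices, class labels, scores), best first. No NMS here."""
    shifted = class_logits - class_logits.max(axis=1, keepdims=True)  # stable softmax
    probs = np.exp(shifted)
    probs /= probs.sum(axis=1, keepdims=True)
    fg = probs[:, 1:]                              # drop the background column
    rows, cols = np.nonzero(fg > score_thresh)     # one entry per (proposal, class) pair
    scores = fg[rows, cols]
    order = np.argsort(-scores)[:max_dets]         # keep at most max_dets, best first
    return rows[order], cols[order] + 1, scores[order]


logits = np.array([[0.0, 5.0, 0.0],   # confident class-1 proposal
                   [5.0, 0.0, 0.0]])  # confident background proposal
rows, labels, scores = filter_detections(logits)
print(rows, labels, scores)
```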

    -> The detection results then go through the mask branch: x, detections, loss_mask = self.mask(mask_features, detections, targets). The mask branch's forward() is:

        def forward(self, features, proposals, targets=None):
            """
            Arguments:
                features (list[Tensor]): feature-maps from possibly several levels
                proposals (list[BoxList]): proposal boxes
                targets (list[BoxList], optional): the ground-truth targets.
    
            Returns:
                x (Tensor): the result of the feature extractor
                proposals (list[BoxList]): during training, the original proposals
                    are returned. During testing, the predicted boxlists are returned
                    with the `mask` field set
                losses (dict[Tensor]): During training, returns the losses for the
                    head. During testing, returns an empty dict.
            """
    
            if self.training:
                # during training, only focus on positive boxes
                all_proposals = proposals
                proposals, positive_inds = keep_only_positive_boxes(proposals)
            if self.training and self.cfg.MODEL.ROI_MASK_HEAD.SHARE_BOX_FEATURE_EXTRACTOR:
                x = features
                x = x[torch.cat(positive_inds, dim=0)]
            else:
                x = self.feature_extractor(features, proposals)
            mask_logits = self.predictor(x)
    
            if not self.training:
                result = self.post_processor(mask_logits, proposals)
                return x, result, {}
    
            loss_mask = self.loss_evaluator(proposals, mask_logits, targets)
    
            return x, all_proposals, dict(loss_mask=loss_mask)

    -> x is the mask-branch feature tensor, now of shape [88, 256, 14, 14]. The returned detections carry both the boxes and their masks:

    ipdb> x.shape
    torch.Size([88, 256, 14, 14])
    ipdb> detections
    [BoxList(num_boxes=88, image_width=1066, image_height=800, mode=xyxy)]
    ipdb> loss_mask
    {}
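    The mask predictor turns x into per-class mask logits. For COCO I would expect a shape like [88, 81, 28, 28] (80 classes plus background, upsampled from 14 to 28), though the sketch below does not depend on the exact numbers. At test time only the channel that matches each box's predicted label is kept, and it is passed through a sigmoid. Below is a NumPy sketch of that selection. select_mask_probs is a hypothetical name.

```python
import numpy as np


def select_mask_probs(mask_logits, labels):
    """mask_logits: (N, num_classes, M, M); labels: (N,) predicted class per box.
    Returns (N, 1, M, M) mask probabilities for each box's own class."""
    n = mask_logits.shape[0]
    picked = mask_logits[np.arange(n), labels]   # (N, M, M): one channel per box
    probs = 1.0 / (1.0 + np.exp(-picked))        # sigmoid
    return probs[:, None]                        # keep a channel axis


logits = np.zeros((2, 3, 4, 4))
logits[0, 1] = 10.0                              # box 0 is confident for class 1
probs = select_mask_probs(logits, np.array([1, 2]))
print(probs.shape)
```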

    -> After that, control returns to generalized_rcnn.py and then to predictor.py, which does some post-processing. All that remains is to visualize the results!
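    Two of the post-processing steps in predictor.py are resizing the predictions back to the original image size and keeping only detections above confidence_threshold (0.7 in the demo above, sorted by score). Below is a hedged NumPy sketch of just those two steps. rescale_and_filter is a hypothetical helper, not the demo's API, and mask pasting is left out.

```python
import numpy as np


def rescale_and_filter(boxes, scores, resized_wh, original_wh, confidence_threshold=0.7):
    """boxes: (N, 4) xyxy in the resized image; sizes are (width, height)."""
    rw = original_wh[0] / resized_wh[0]
    rh = original_wh[1] / resized_wh[1]
    scaled = boxes * np.array([rw, rh, rw, rh])      # back to original pixel coordinates
    keep = np.nonzero(scores > confidence_threshold)[0]
    order = keep[np.argsort(-scores[keep])]          # highest score first
    return scaled[order], scores[order]


boxes = np.array([[0.0, 0.0, 1066.0, 800.0], [533.0, 400.0, 1066.0, 800.0]])
scores = np.array([0.6, 0.9])
kept_boxes, kept_scores = rescale_and_filter(boxes, scores, (1066, 800), (640, 480))
print(kept_boxes, kept_scores)
```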

  • Original post: https://www.cnblogs.com/ranjiewen/p/10001590.html