zoukankan      html  css  js  c++  java
  • mmdetection模型构建及Registry注册器机制

            好久没有做目标检测了,最近突然又接到了检测任务,跟同事讨论时,发现自己竟然忘了很多细节,

    于是想趁训练模型的间隙,重新梳理下目标检测。我选择了mmdetection来学习,除了目标检测本身,

    这个框架中很多python的使用技巧和框架的设计模式也是值得学习。最近一年基本都在使用python,

    希望能将这些技巧应用在以后的工作之中。mmdetection封装的很好方便使用,比如我想训练的

    话只需如下的一条指令。在train.py中,通过build_detector来构建模型(参数来自 faster_rcnn_r50_fpn_1x_voc0712.py),

    python tools/train.py  configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc0712.py

    build_detector的定义如下,最后通过build_from_cfg来构建模型,这里看到了让人困惑的Registry.

    from mmdet.cv_core.utils import Registry, build_from_cfg
    from torch import nn
    
    BACKBONES = Registry('backbone')
    NECKS = Registry('neck')
    ROI_EXTRACTORS = Registry('roi_extractor')
    SHARED_HEADS = Registry('shared_head')
    HEADS = Registry('head')
    LOSSES = Registry('loss')
    DETECTORS = Registry('detector')
    
    def build(cfg, registry, default_args=None):
        """Build a module.
    
        Args:
            cfg (dict, list[dict]): The config of modules, is is either a dict
                or a list of configs.
            registry (:obj:`Registry`): A registry the module belongs to.
            default_args (dict, optional): Default arguments to build the module.
                Defaults to None.
    
        Returns:
            nn.Module: A built nn module.
        """
        if isinstance(cfg, list):
            modules = [
                build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
            ]
            return nn.Sequential(*modules)
        else:
            return build_from_cfg(cfg, registry, default_args)
    
    
    def build_detector(cfg, train_cfg=None, test_cfg=None):
        """Build detector."""
        return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))

           

    一、Registry是干什么的

            Registry完成了从字符串到类的映射,这样模型信息、训练时的参数信息,只需要写入到一个配置文件里,然后使用注册器来实例化即可。

    二、如何实现

            通过装饰器来实现。在mmcv/mmcv/registry.py中,我们看到了Registry类。其中完成字符串到类的映射,实际上就是下面的成员函数来实现的,核心代码就一句,将要注册的类添加到字典里,key为类的名字(字符串)。下面通过一个小例子,

     def _register_module(self, module_class, module_name=None, force=False):
            if not inspect.isclass(module_class):
                raise TypeError('module must be a class, '
                                f'but got {type(module_class)}')
    
            if module_name is None:
                module_name = module_class.__name__
            if not force and module_name in self._module_dict:
                raise KeyError(f'{module_name} is already registered '
                               f'in {self.name}')
            self._module_dict[module_name] = module_class

     来看看它的构建过程。在导入下面这个文件时,首先创建FRUIT实例,接着通过装饰器(这里是用成员函数装饰类)来注册Apple类,调用register_module,然后调用_register(注意:参数cls即为类Apple),最后调用_register_module完成Apple的添加。完成后,FRUIT就有了个字典成员:['Apple']=APPle。在build_from_cfg中,传入模型参数,即可通过FRUIT构建Apple的实例化对象。

    class Registry():
    
        def __init__(self, name):
            self._name = name
            self._module_dict = dict()
    
        def _register_module(self, module_class, module_name, force):
            self._module_dict[module_name] = module_class
    
    
        def register_module(self, name=None, force=False, module=None):
            print('register module ...')
            def _register(cls):
                print('cls ', cls)
                self._register_module(
                    module_class=cls, module_name=name, force=force)
                return cls
    
            return _register
    
    FRUIT = Registry('fruit')
    
    @FRUIT.register_module()
    class Apple():
        def __init__(self, name):
            self.name = name

    def build_from_cfg(cfg, registry, default_args=None):
       

        args = cfg.copy()

        if default_args is not None:
            for name, value in default_args.items():
                args.setdefault(name, value)

        obj_type = args.pop('type')
        if is_str(obj_type):
            obj_cls = registry.get(obj_type)
        
        return obj_cls(**args)

    三、Registry在mmdetection中是如何构建模型的

              我们来看一下构建模型的流程:

            1、在train.py中通过build_detector构建模型,其中cfg.model, cfg.train_cfg如下,包括模型信息和训练信息。

    model = build_detector(
            cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg) 

             

            2、最关键的部分来了。首先通过build_detector构建模型, 其中传入的DETECTORS是Registry的实例,在该实例中,包含了所有已经实现的检测器,如图。那么它是在哪里实现添加这些检测的类的呢?

    def build_detector(cfg, train_cfg=None, test_cfg=None):
        """Build detector."""
        return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))

               看了前面那个小例子我们就能猜到,一定是在这些检测类上,用Registry对其进行了注册,看看faster rcnn的实现,证明了我们的猜想。这样只要

    在定义这些类时,对其进行注册,那么就会自动加入到DETECTORS这个实例的成员字典里,非常的巧妙。当我们想实例化某个检测网络时,传入其字符名称

    即可。

           既然都看到这里了,就进一步看看网络时如何继续构建的吧。mmdetection将网络分成了几个部分,backbone,head,neck等。在TwoStageDetector(

    faster rcnn的基类)中,可以看到分别构建了这几个部分。head, neck, loss等,同样是通过Registry来注册实现的。最后就是将这几个部分组合起来即可。

    @DETECTORS.register_module()
    class TwoStageDetector(BaseDetector):
        """Base class for two-stage detectors.
    
        Two-stage detectors typically consisting of a region proposal network and a
        task-specific regression head.
        """
    
        def __init__(self,
                     backbone,
                     neck=None,
                     rpn_head=None,
                     roi_head=None,
                     train_cfg=None,
                     test_cfg=None,
                     pretrained=None):
            super(TwoStageDetector, self).__init__()
            self.backbone = build_backbone(backbone)
    
            if neck is not None:
                self.neck = build_neck(neck)
    
            if rpn_head is not None:
                rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None
                rpn_head_ = rpn_head.copy()
                rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn)
                self.rpn_head = build_head(rpn_head_)
    
            if roi_head is not None:
                # update train and test cfg here for now
                # TODO: refactor assigner & sampler
                rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None
                roi_head.update(train_cfg=rcnn_train_cfg)
                roi_head.update(test_cfg=test_cfg.rcnn)
                self.roi_head = build_head(roi_head)
    
            self.train_cfg = train_cfg
            self.test_cfg = test_cfg
    
            self.init_weights(pretrained=pretrained)

    四、Registry的应用

              在我最近的一个数据处理的项目中,有三类数据,sample, measure 和image。如果我想得到某个数据类型的实例,我是通过if来

    判断的。那如果数据类别很多呢?就像检测器这样有几十种,再用if就显得很蠢了。借用Registry机制,可以轻松解决这个问题。

     

            

                  

          

           

          

  • 相关阅读:
    补充 函数详解
    Python web前端 11 form 和 ajax
    进程线程之间的通信
    面向对象epoll并发
    socket发送静态页面
    进程与线程的表示,属性,守护模式
    并发
    django, tornado
    并行
    非阻塞套接字编程, IO多路复用(epoll)
  • 原文地址:https://www.cnblogs.com/573177885qq/p/14307875.html
Copyright © 2011-2022 走看看