zoukankan      html  css  js  c++  java
  • mmdection的注册模块Registry解读

    理解mmdection模块基础:

    1. @ 的装饰器需理解(很重要)。

    2.class 类简单实列化等操作。

    简单应用类的装饰器的一个列子,理解注册模块:

    简单说明:注册模块实际是通过字典保存名字对应类地址,其中最重要的是Registry类中

    self._module_dict = dict()未定义的的字典中添加注册类的地址,以便于后续访问。
    而添加类对象地址通过该函数添加 _register_module(self, module_class)。
    如下FOOD实列中添加了Rice与Noodles类的地址。

    builder.py 文件代码如下:

    from registry import Registry, build_from_cfg

    FRUIT = Registry('fruit')
    FOOD = Registry('food')
    def build(cfg, registry, default_args=None):
    return build_from_cfg(cfg, registry, default_args)

    def build_fruit(cfg):
    return build(cfg, FRUIT)
    def build_food(cfg):
    return build(cfg, FOOD)
    @FOOD.register_module
    class Rice():
    def __init__(self, name):
    self.name = name
    @FOOD.register_module
    class Noodles(): # 实际FOOD实列化后的
    def __init__(self):
    pass
    @FRUIT.register_module
    class Apple():
    def __init__(self, name):
    self.name = name

    registry.py文件代码如下:
    import inspect
    import six

    def is_str(x):
    """Whether the input is an string instance."""
    return isinstance(x, six.string_types)


    class Registry(object):

    def __init__(self, name):
    self._name = name # 此处的self,是个对象(Object),是当前类的实例,name即为传进来的'detector'值
    self._module_dict = dict() # 定义的属性,是一个字典

    @property
    def name(self): # 把方法变成属性,通过self.name 就能获得name的值。我感觉是一个私有函数
    return self._name

    @property
    def module_dict(self):
    return self._module_dict

    def get(self, key):
    return self._module_dict.get(key, None)

    def _register_module(self, module_class):
    """
    关键的一个方法,作用就是Register a module.
    在model文件夹下的py文件中,里面的class定义上面都会出现 @DETECTORS.register_module,意思就是将类当做形参,
    将类送入了方法register_module()中执行。@的具体用法看后面解释。
    Register a module.

    Args:
    module (:obj:`nn.Module`): Module to be registered.
    """
    # if not inspect.isclass(module_class): # 判断是否为类,是类的话,就为True,否则报错
    # raise TypeError('module must be a class, but got {}'.format(
    # type(module_class)))
    module_name = module_class.__name__ # 获取类名
    if module_name in self._module_dict: # 看该类是否已经登记在属性_module_dict中
    raise KeyError('{} is already registered in {}'.format(
    module_name, self.name))
    self._module_dict[module_name] = module_class # 在module中dict新增key和value。key为类名,value为类对象

    def register_module(self, cls): # 对上面的方法,修改了名字,添加了返回值,即返回类本身
    self._register_module(cls)
    return cls

    def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
    cfg (dict): Config dict. It should at least contain the key "type".
    registry (:obj:`Registry`): The registry to search the type from.
    default_args (dict, optional): Default initialization arguments.

    Returns:
    obj: The constructed object.
    """
    assert isinstance(cfg, dict) and 'type' in cfg
    assert isinstance(default_args, dict) or default_args is None
    args = cfg.copy()
    obj_type = args.pop('type')
    if is_str(obj_type):
    obj_cls = registry.get(obj_type)
    if obj_cls is None:
    raise KeyError('{} is not in the {} registry'.format(
    obj_type, registry.name))
    elif inspect.isclass(obj_type):
    obj_cls = obj_type
    else:
    raise TypeError('type must be a str or valid type, but got {}'.format(
    type(obj_type)))
    if default_args is not None:
    for name, value in default_args.items():
    args.setdefault(name, value)


    return obj_cls(**args)


    lunch.py文件代码如下:
    lunch=dict(
    food=dict(type='Rice', name='东北大米'),
    fruit=dict(type='Apple', name='青苹果')
    )

    demo.py文件代码如下:

    from build import build_fruit, build_food
    from lunch import lunch


    class COOKER():
    def __init__(self,food, fruit):
    print('今日饮食清单:{}, {}'.format(food, fruit))
    self.food = build_food(food)
    self.fruit = build_fruit(fruit)

    def run(self):
    print('具体饮食计划')
    print('主食吃: {}'.format(self.food.name))
    print('水果吃: {}'.format(self.fruit.name))

    cook = COOKER(**lunch)
    cook.run()

    结果如下:

    mmdection的注册使用:

    Registry.py文件用来添加模块注册机制的,实际是通过registry类建立的实列如BACKNONE实列中添加模块,如resnet等

    并通过self._module_dict 字典保存,该文件与上面的registry.py文件相同,不贴代码了。

    builders.py 文件用来读取config信息,并通过registry文件中build_from_cfg函数来保存信息,实际保存为类中初始化变量信息(参数),

    通过赋值调用注册模块中符合要求的类。

    from torch import nn

    from mmdet.utils import build_from_cfg
    from .registry import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS,
    ROI_EXTRACTORS, SHARED_HEADS)


    def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
    modules = [
    build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
    ]
    return nn.Sequential(*modules)
    else:
    return build_from_cfg(cfg, registry, default_args)
    # default_args 来自config的train_cfg与test_cfg


    def build_backbone(cfg):
    return build(cfg, BACKBONES)


    def build_neck(cfg):
    return build(cfg, NECKS)


    def build_roi_extractor(cfg):
    return build(cfg, ROI_EXTRACTORS)


    def build_shared_head(cfg):
    return build(cfg, SHARED_HEADS)


    def build_head(cfg):
    return build(cfg, HEADS)


    def build_loss(cfg):
    return build(cfg, LOSSES)


    def build_detector(cfg, train_cfg=None, test_cfg=None):
    return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))

    model/registry.py文件纯粹是用来实列化模块的,也是这样的实列化,通过@符号添加模块的

    from mmdet.utils import Registry

    BACKBONES = Registry('backbone')
    NECKS = Registry('neck')
    ROI_EXTRACTORS = Registry('roi_extractor')
    SHARED_HEADS = Registry('shared_head')
    HEADS = Registry('head')
    LOSSES = Registry('loss')
    DETECTORS = Registry('detector')

    模块添加就是通过@ ,如同上个列子中的FOOD添加模块一样,而mmdection添加如同下图。

    以上大致为mmdection代码注册模块。

  • 相关阅读:
    (转)正则表达式与Python(RE)模块
    (转)【面试】【MySQL常见问题总结】【03】
    (转)MySQL性能调优my.cnf详解
    (转)python logging模块
    (转)python collections模块详解
    mysql故障总结
    rocksdb 编译安装 日志
    c++11 gcc4.8.x安装
    Install ssdb-rocks on CentOS 6
    在Shell里面判断字符串是否为空
  • 原文地址:https://www.cnblogs.com/tangjunjun/p/14019009.html
Copyright © 2011-2022 走看看