Skip to content

Commit

Permalink
[Feature] Registry supports import modules automatically. (open-mmlab…
Browse files Browse the repository at this point in the history
…#643)

* [Feature] Support registry auto import modules.

* update

* rebase and fix ut

* add docstring

* remove count_registered_modules

* update docstring

* resolve comments

* resolve comments

* rename ut

* fix warning

* avoid BC breaking

* update doc

* update doc

* resolve comments
  • Loading branch information
RangiLyu authored Dec 23, 2022
1 parent 60492f4 commit e83ac94
Show file tree
Hide file tree
Showing 9 changed files with 182 additions and 62 deletions.
23 changes: 16 additions & 7 deletions docs/en/advanced_tutorials/registry.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,18 @@ There are three steps required to use the registry to manage modules in the code

Suppose we want to implement a series of activation modules and want to be able to switch to different modules by just modifying the configuration without modifying the code.

Let's create a regitry first.
Let's create a registry first.

```python
from mmengine import Registry
# scope represents the domain of the registry. If not set, the default value is the package name.
# `scope` represents the domain of the registry. If not set, the default value is the package name.
# e.g. in mmdetection, the scope is mmdet
ACTIVATION = Registry('activation', scope='mmengine')
# `locations` indicates the location where the modules in this registry are defined.
# The Registry will automatically import the modules when building them according to these predefined locations.
ACTIVATION = Registry('activation', scope='mmengine', locations=['mmengine.models.activations'])
```

Then we can implement different activation modules, such as `Sigmoid`, `ReLU`, and `Softmax`.
The module `mmengine.models.activations` specified by `locations` corresponds to the `mmengine/models/activations.py` file. When building modules with registry, the ACTIVATION registry will automatically import implemented modules from this file. Therefore, we can implement different activation layers in the `mmengine/models/activations.py` file, such as `Sigmoid`, `ReLU`, and `Softmax`.

```python
import torch.nn as nn
Expand Down Expand Up @@ -75,7 +77,14 @@ print(ACTIVATION.module_dict)
```

```{note}
The registry mechanism will only be triggered when the corresponded module file is imported, so we need to import the file somewhere or dynamically import the module using the ``custom_imports`` field to trigger the mechanism. Please refer to [Importing custom Python modules](config.md#import-the-custom-module) for more details.
The key to trigger the registry mechanism is to make the module imported.
There are three ways to register a module into the registry
1. Implement the module in the ``locations``. The registry will automatically import modules in the predefined locations. This is to ease the usage of algorithm libraries so that users can directly use ``REGISTRY.build(cfg)``.
2. Import the file manually. This is common when developers implement a new module in/out side the algorithm library.
3. Use ``custom_imports`` field in config. Please refer to [Importing custom Python modules](config.md#import-the-custom-module) for more details.
```

Once the implemented module is successfully registered, we can use the activation module in the configuration file.
Expand Down Expand Up @@ -119,7 +128,7 @@ def build_activation(cfg, registry, *args, **kwargs):
Pass the `buid_activation` to `build_func`.

```python
ACTIVATION = Registry('activation', build_func=build_activation, scope='mmengine')
ACTIVATION = Registry('activation', build_func=build_activation, scope='mmengine', locations=['mmengine.models.activations'])

@ACTIVATION.register_module()
class Tanh(nn.Module):
Expand Down Expand Up @@ -206,7 +215,7 @@ Now suppose there is a project called `MMAlpha`, which also defines a `MODELS` a
```python
from mmengine import Registry, MODELS as MMENGINE_MODELS

MODELS = Registry('model', parent=MMENGINE_MODELS, scope='mmalpha')
MODELS = Registry('model', parent=MMENGINE_MODELS, scope='mmalpha', locations=['mmalpha.models'])
```

The following figure shows the hierarchy of `MMEngine` and `MMAlpha`.
Expand Down
1 change: 1 addition & 0 deletions docs/en/api/registry.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ mmengine.registry
build_scheduler_from_cfg
count_registered_modules
traverse_registry_tree
init_default_scope
17 changes: 12 additions & 5 deletions docs/zh_cn/advanced_tutorials/registry.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ MMEngine 实现的[注册器](mmengine.registry.Registry)可以看作一个映
```python
from mmengine import Registry
# scope 表示注册器的作用域,如果不设置,默认为包名,例如在 mmdetection 中,它的 scope 为 mmdet
ACTIVATION = Registry('activation', scope='mmengine')
# locations 表示注册在此注册器的模块所存放的位置,注册器会根据预先定义的位置在构建模块时自动 import
ACTIVATION = Registry('activation', scope='mmengine', locations=['mmengine.models.activations'])
```

然后我们可以实现不同的激活模块,例如 `Sigmoid``ReLU``Softmax`
`locations` 指定的模块 `mmengine.models.activations` 对应了 `mmengine/models/activations.py` 文件。在使用注册器构建模块的时候,ACTIVATION 注册器会自动从该文件中导入实现的模块。因此,我们可以在 `mmengine/models/activations.py` 文件中实现不同的激活函数,例如 `Sigmoid``ReLU``Softmax`

```python
import torch.nn as nn
Expand Down Expand Up @@ -74,7 +75,13 @@ print(ACTIVATION.module_dict)
```

```{note}
只有模块所在的文件被导入时,注册机制才会被触发,所以我们需要在某处导入该文件或者使用 `custom_imports` 字段动态导入该模块进而触发注册机制,详情见[导入自定义 Python 模块](config.md#导入自定义-python-模块)。
只有模块所在的文件被导入时,注册机制才会被触发,用户可以通过三种方式将模块添加到注册器中:
1. 在 ``locations`` 指向的文件中实现模块。注册器将自动在预先定义的位置导入模块。这种方式是为了简化算法库的使用,以便用户可以直接使用 ``REGISTRY.build(cfg)``。
2. 手动导入文件。常用于用户在算法库之内或之外实现新的模块。
3. 在配置中使用 ``custom_imports`` 字段。 详情请参考[导入自定义Python模块](config.md#import-the-custom-module)。
```

模块成功注册后,我们可以通过配置文件使用这个激活模块。
Expand Down Expand Up @@ -119,7 +126,7 @@ def build_activation(cfg, registry, *args, **kwargs):
并将 `build_activation` 传递给 `build_func` 参数

```python
ACTIVATION = Registry('activation', build_func=build_activation, scope='mmengine')
ACTIVATION = Registry('activation', build_func=build_activation, scope='mmengine', locations=['mmengine.models.activations'])

@ACTIVATION.register_module()
class Tanh(nn.Module):
Expand Down Expand Up @@ -206,7 +213,7 @@ class RReLU(nn.Module):
```python
from mmengine import Registry, MODELS as MMENGINE_MODELS

MODELS = Registry('model', parent=MMENGINE_MODELS, scope='mmalpha')
MODELS = Registry('model', parent=MMENGINE_MODELS, scope='mmalpha', locations=['mmalpha.models'])
```

下图是 `MMEngine``MMAlpha` 的注册器层级结构。
Expand Down
1 change: 1 addition & 0 deletions docs/zh_cn/api/registry.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ mmengine.registry
build_scheduler_from_cfg
count_registered_modules
traverse_registry_tree
init_default_scope
5 changes: 3 additions & 2 deletions mmengine/registry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS, OPTIMIZERS,
PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS,
TRANSFORMS, VISBACKENDS, VISUALIZERS, WEIGHT_INITIALIZERS)
from .utils import count_registered_modules, traverse_registry_tree
from .utils import (count_registered_modules, init_default_scope,
traverse_registry_tree)

__all__ = [
'Registry', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS', 'DATASETS',
Expand All @@ -18,5 +19,5 @@
'VISBACKENDS', 'VISUALIZERS', 'LOG_PROCESSORS', 'EVALUATOR',
'DefaultScope', 'traverse_registry_tree', 'count_registered_modules',
'build_model_from_cfg', 'build_runner_from_cfg', 'build_from_cfg',
'build_scheduler_from_cfg'
'build_scheduler_from_cfg', 'init_default_scope'
]
145 changes: 106 additions & 39 deletions mmengine/registry/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ class Registry:
for children registry. If not specified, scope will be the name of
the package where class is defined, e.g. mmdet, mmcls, mmseg.
Defaults to None.
locations (list): The locations to import the modules registered
in this registry. Defaults to [].
New in version 0.4.0.
Examples:
>>> # define a registry
Expand All @@ -54,6 +57,16 @@ class Registry:
>>> pass
>>> fasterrcnn = DETECTORS.build(dict(type='FasterRCNN'))
>>> # add locations to enable auto import
>>> DETECTORS = Registry('detectors', parent=MODELS,
>>> scope='det', locations=['det.models.detectors'])
>>> # define this class in 'det.models.detectors'
>>> @DETECTORS.register_module()
>>> class MaskRCNN:
>>> pass
>>> # The registry will auto import det.models.detectors.MaskRCNN
>>> fasterrcnn = DETECTORS.build(dict(type='det.MaskRCNN'))
More advanced usages can be found at
https://mmengine.readthedocs.io/en/latest/tutorials/registry.html.
"""
Expand All @@ -62,11 +75,14 @@ def __init__(self,
name: str,
build_func: Optional[Callable] = None,
parent: Optional['Registry'] = None,
scope: Optional[str] = None):
scope: Optional[str] = None,
locations: List = []):
from .build_functions import build_from_cfg
self._name = name
self._module_dict: Dict[str, Type] = dict()
self._children: Dict[str, 'Registry'] = dict()
self._locations = locations
self._imported = False

if scope is not None:
assert isinstance(scope, str)
Expand Down Expand Up @@ -240,27 +256,25 @@ def switch_scope_and_registry(self, scope: str) -> Generator:
# Get registry by scope
if default_scope is not None:
scope_name = default_scope.scope_name
if scope_name in PKG2PROJECT:
try:
module = import_module(f'{scope_name}.utils')
module.register_all_modules(False) # type: ignore
except (ImportError, AttributeError, ModuleNotFoundError):
if scope in PKG2PROJECT:
print_log(
f'{scope} is not installed and its '
'modules will not be registered. If you '
'want to use modules defined in '
f'{scope}, Please install {scope} by '
f'`pip install {PKG2PROJECT[scope]}.',
logger='current',
level=logging.WARNING)
else:
print_log(
f'Failed to import {scope} and register '
'its modules, please make sure you '
'have registered the module manually.',
logger='current',
level=logging.WARNING)
try:
import_module(f'{scope_name}.registry')
except (ImportError, AttributeError, ModuleNotFoundError):
if scope in PKG2PROJECT:
print_log(
f'{scope} is not installed and its '
'modules will not be registered. If you '
'want to use modules defined in '
f'{scope}, Please install {scope} by '
f'`pip install {PKG2PROJECT[scope]}.',
logger='current',
level=logging.WARNING)
else:
print_log(
f'Failed to import `{scope}.registry` '
f'make sure the registry.py exists in `{scope}` '
'package.',
logger='current',
level=logging.WARNING)
root = self._get_root_registry()
registry = root._search_child(scope_name)
if registry is None:
Expand Down Expand Up @@ -290,6 +304,59 @@ def _get_root_registry(self) -> 'Registry':
root = root.parent
return root

def import_from_location(self) -> None:
"""import modules from the pre-defined locations in self._location."""
if not self._imported:
# Avoid circular import
from ..logging import print_log

# avoid BC breaking
if len(self._locations) == 0 and self.scope in PKG2PROJECT:
print_log(
f'The "{self.name}" registry in {self.scope} did not '
'set import location. Fallback to call '
f'`{self.scope}.utils.register_all_modules` '
'instead.',
logger='current',
level=logging.WARNING)
try:
module = import_module(f'{self.scope}.utils')
module.register_all_modules(False) # type: ignore
except (ImportError, AttributeError, ModuleNotFoundError):
if self.scope in PKG2PROJECT:
print_log(
f'{self.scope} is not installed and its '
'modules will not be registered. If you '
'want to use modules defined in '
f'{self.scope}, Please install {self.scope} by '
f'`pip install {PKG2PROJECT[self.scope]}.',
logger='current',
level=logging.WARNING)
else:
print_log(
f'Failed to import {self.scope} and register '
'its modules, please make sure you '
'have registered the module manually.',
logger='current',
level=logging.WARNING)

for loc in self._locations:
try:
import_module(loc)
print_log(
f"Modules of {self.scope}'s {self.name} registry have "
f'been automatically imported from {loc}',
logger='current',
level=logging.DEBUG)
except (ImportError, AttributeError, ModuleNotFoundError):
print_log(
f'Failed to import {loc}, please check the '
f'location of the registry {self.name} is '
'correct.',
logger='current',
level=logging.WARNING)
self._imported = True

def get(self, key: str) -> Optional[Type]:
"""Get the registry record.
Expand Down Expand Up @@ -346,11 +413,14 @@ def get(self, key: str) -> Optional[Type]:
obj_cls = None
registry_name = self.name
scope_name = self.scope

# lazy import the modules to register them into the registry
self.import_from_location()

if scope is None or scope == self._scope:
# get from self
if real_key in self._module_dict:
obj_cls = self._module_dict[real_key]

elif scope is None:
# try to get the target from its parent or ancestors
parent = self.parent
Expand All @@ -362,24 +432,21 @@ def get(self, key: str) -> Optional[Type]:
break
parent = parent.parent
else:
# import the registry to add the nodes into the registry tree
try:
module = import_module(f'{scope}.utils')
module.register_all_modules(False) # type: ignore
import_module(f'{scope}.registry')
print_log(
f'Registry node of {scope} has been automatically '
'imported.',
logger='current',
level=logging.DEBUG)
except (ImportError, AttributeError, ModuleNotFoundError):
if scope in PKG2PROJECT:
print_log(
f'{scope} is not installed and its modules '
'will not be registered. If you want to use '
f'modules defined in {scope}, Please install '
f'{scope} by `pip install {PKG2PROJECT[scope]} ',
logger='current',
level=logging.WARNING)
else:
print_log(
f'Failed to import "{scope}", and register its '
f'modules. Please register {real_key} manually.',
logger='current',
level=logging.WARNING)
print_log(
f'Cannot auto import {scope}.registry, please check '
f'whether the package "{scope}" is installed correctly '
'or import the registry manually.',
logger='current',
level=logging.DEBUG)
# get from self._children
if scope in self._children:
obj_cls = self._children[scope].get(real_key)
Expand Down
24 changes: 24 additions & 0 deletions mmengine/registry/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
import datetime
import os.path as osp
import warnings
from typing import Optional

from mmengine.fileio import dump
from mmengine.logging import print_log
from . import root
from .default_scope import DefaultScope
from .registry import Registry


Expand Down Expand Up @@ -90,3 +92,25 @@ def count_registered_modules(save_path: Optional[str] = None,
dump(scan_data, json_path, indent=2)
print_log(f'Result has been saved to {json_path}', logger='current')
return scan_data


def init_default_scope(scope: str) -> None:
"""Initialize the given default scope.
Args:
scope (str): The name of the default scope.
"""
never_created = DefaultScope.get_current_instance(
) is None or not DefaultScope.check_instance_created(scope)
if never_created:
DefaultScope.get_instance(scope, scope_name=scope)
return
current_scope = DefaultScope.get_current_instance() # type: ignore
if current_scope.scope_name != scope: # type: ignore
warnings.warn('The current default scope ' # type: ignore
f'"{current_scope.scope_name}" is not "{scope}", '
'`init_default_scope` will force set the current'
f'default scope to "{scope}".')
# avoid name conflict
new_instance_name = f'{scope}-{datetime.datetime.now()}'
DefaultScope.get_instance(new_instance_name, scope_name=scope)
8 changes: 1 addition & 7 deletions mmengine/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@
from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS,
LOG_PROCESSORS, LOOPS, MODEL_WRAPPERS, MODELS,
OPTIM_WRAPPERS, PARAM_SCHEDULERS, RUNNERS,
VISUALIZERS, DefaultScope,
count_registered_modules)
VISUALIZERS, DefaultScope)
from mmengine.utils import digit_version, get_git_hash, is_seq_of
from mmengine.utils.dl_utils import (TORCH_VERSION, collect_env,
set_multi_processing)
Expand Down Expand Up @@ -372,11 +371,6 @@ def __init__(
# Collect and log environment information.
self._log_env(env_cfg)

# collect information of all modules registered in the registries
registries_info = count_registered_modules(
self.work_dir if self.rank == 0 else None, verbose=False)
self.logger.debug(registries_info)

# Build `message_hub` for communication among components.
# `message_hub` can store log scalars (loss, learning rate) and
# runtime information (iter and epoch). Those components that do not
Expand Down
Loading

0 comments on commit e83ac94

Please sign in to comment.