From e83ac944b6e380781ea42a714a55a33c075cc904 Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Fri, 23 Dec 2022 15:46:29 +0800 Subject: [PATCH] [Feature] Registry supports import modules automatically. (#643) * [Feature] Support registry auto import modules. * update * rebase and fix ut * add docstring * remove count_registered_modules * update docstring * resolve comments * resolve comments * rename ut * fix warning * avoid BC breaking * update doc * update doc * resolve comments --- docs/en/advanced_tutorials/registry.md | 23 +++- docs/en/api/registry.rst | 1 + docs/zh_cn/advanced_tutorials/registry.md | 17 ++- docs/zh_cn/api/registry.rst | 1 + mmengine/registry/__init__.py | 5 +- mmengine/registry/registry.py | 145 +++++++++++++++------ mmengine/registry/utils.py | 24 ++++ mmengine/runner/runner.py | 8 +- tests/test_registry/test_registry_utils.py | 20 ++- 9 files changed, 182 insertions(+), 62 deletions(-) diff --git a/docs/en/advanced_tutorials/registry.md b/docs/en/advanced_tutorials/registry.md index eb875961de..bbaa454d36 100644 --- a/docs/en/advanced_tutorials/registry.md +++ b/docs/en/advanced_tutorials/registry.md @@ -18,16 +18,18 @@ There are three steps required to use the registry to manage modules in the code Suppose we want to implement a series of activation modules and want to be able to switch to different modules by just modifying the configuration without modifying the code. -Let's create a regitry first. +Let's create a registry first. ```python from mmengine import Registry -# scope represents the domain of the registry. If not set, the default value is the package name. +# `scope` represents the domain of the registry. If not set, the default value is the package name. # e.g. in mmdetection, the scope is mmdet -ACTIVATION = Registry('activation', scope='mmengine') +# `locations` indicates the location where the modules in this registry are defined. 
+# The Registry will automatically import the modules when building them according to these predefined locations. +ACTIVATION = Registry('activation', scope='mmengine', locations=['mmengine.models.activations']) ``` -Then we can implement different activation modules, such as `Sigmoid`, `ReLU`, and `Softmax`. +The module `mmengine.models.activations` specified by `locations` corresponds to the `mmengine/models/activations.py` file. When building modules with the registry, the ACTIVATION registry will automatically import implemented modules from this file. Therefore, we can implement different activation layers in the `mmengine/models/activations.py` file, such as `Sigmoid`, `ReLU`, and `Softmax`. ```python import torch.nn as nn @@ -75,7 +77,14 @@ print(ACTIVATION.module_dict) ``` ```{note} -The registry mechanism will only be triggered when the corresponded module file is imported, so we need to import the file somewhere or dynamically import the module using the ``custom_imports`` field to trigger the mechanism. Please refer to [Importing custom Python modules](config.md#import-the-custom-module) for more details. +The key to triggering the registry mechanism is to get the module imported. +There are three ways to register a module into the registry: + +1. Implement the module in the ``locations``. The registry will automatically import modules in the predefined locations. This is to ease the usage of algorithm libraries so that users can directly use ``REGISTRY.build(cfg)``. + +2. Import the file manually. This is common when developers implement a new module inside or outside the algorithm library. + +3. Use the ``custom_imports`` field in the config. Please refer to [Importing custom Python modules](config.md#import-the-custom-module) for more details. ``` Once the implemented module is successfully registered, we can use the activation module in the configuration file. 
@@ -119,7 +128,7 @@ def build_activation(cfg, registry, *args, **kwargs): Pass the `buid_activation` to `build_func`. ```python -ACTIVATION = Registry('activation', build_func=build_activation, scope='mmengine') +ACTIVATION = Registry('activation', build_func=build_activation, scope='mmengine', locations=['mmengine.models.activations']) @ACTIVATION.register_module() class Tanh(nn.Module): @@ -206,7 +215,7 @@ Now suppose there is a project called `MMAlpha`, which also defines a `MODELS` a ```python from mmengine import Registry, MODELS as MMENGINE_MODELS -MODELS = Registry('model', parent=MMENGINE_MODELS, scope='mmalpha') +MODELS = Registry('model', parent=MMENGINE_MODELS, scope='mmalpha', locations=['mmalpha.models']) ``` The following figure shows the hierarchy of `MMEngine` and `MMAlpha`. diff --git a/docs/en/api/registry.rst b/docs/en/api/registry.rst index 84bbba8cc3..c05e38db4b 100644 --- a/docs/en/api/registry.rst +++ b/docs/en/api/registry.rst @@ -24,3 +24,4 @@ mmengine.registry build_scheduler_from_cfg count_registered_modules traverse_registry_tree + init_default_scope diff --git a/docs/zh_cn/advanced_tutorials/registry.md b/docs/zh_cn/advanced_tutorials/registry.md index 89a8de854b..2be9d6e527 100644 --- a/docs/zh_cn/advanced_tutorials/registry.md +++ b/docs/zh_cn/advanced_tutorials/registry.md @@ -23,10 +23,11 @@ MMEngine 实现的[注册器](mmengine.registry.Registry)可以看作一个映 ```python from mmengine import Registry # scope 表示注册器的作用域,如果不设置,默认为包名,例如在 mmdetection 中,它的 scope 为 mmdet -ACTIVATION = Registry('activation', scope='mmengine') +# locations 表示注册在此注册器的模块所存放的位置,注册器会根据预先定义的位置在构建模块时自动 import +ACTIVATION = Registry('activation', scope='mmengine', locations=['mmengine.models.activations']) ``` -然后我们可以实现不同的激活模块,例如 `Sigmoid`,`ReLU` 和 `Softmax`。 +`locations` 指定的模块 `mmengine.models.activations` 对应了 `mmengine/models/activations.py` 文件。在使用注册器构建模块的时候,ACTIVATION 注册器会自动从该文件中导入实现的模块。因此,我们可以在 `mmengine/models/activations.py` 文件中实现不同的激活函数,例如 `Sigmoid`,`ReLU` 和 `Softmax`。 
```python import torch.nn as nn @@ -74,7 +75,13 @@ print(ACTIVATION.module_dict) ``` ```{note} -只有模块所在的文件被导入时,注册机制才会被触发,所以我们需要在某处导入该文件或者使用 `custom_imports` 字段动态导入该模块进而触发注册机制,详情见[导入自定义 Python 模块](config.md#导入自定义-python-模块)。 +只有模块所在的文件被导入时,注册机制才会被触发,用户可以通过三种方式将模块添加到注册器中: + +1. 在 ``locations`` 指向的文件中实现模块。注册器将自动在预先定义的位置导入模块。这种方式是为了简化算法库的使用,以便用户可以直接使用 ``REGISTRY.build(cfg)``。 + +2. 手动导入文件。常用于用户在算法库之内或之外实现新的模块。 + +3. 在配置中使用 ``custom_imports`` 字段。 详情请参考[导入自定义Python模块](config.md#import-the-custom-module)。 ``` 模块成功注册后,我们可以通过配置文件使用这个激活模块。 @@ -119,7 +126,7 @@ def build_activation(cfg, registry, *args, **kwargs): 并将 `build_activation` 传递给 `build_func` 参数 ```python -ACTIVATION = Registry('activation', build_func=build_activation, scope='mmengine') +ACTIVATION = Registry('activation', build_func=build_activation, scope='mmengine', locations=['mmengine.models.activations']) @ACTIVATION.register_module() class Tanh(nn.Module): @@ -206,7 +213,7 @@ class RReLU(nn.Module): ```python from mmengine import Registry, MODELS as MMENGINE_MODELS -MODELS = Registry('model', parent=MMENGINE_MODELS, scope='mmalpha') +MODELS = Registry('model', parent=MMENGINE_MODELS, scope='mmalpha', locations=['mmalpha.models']) ``` 下图是 `MMEngine` 和 `MMAlpha` 的注册器层级结构。 diff --git a/docs/zh_cn/api/registry.rst b/docs/zh_cn/api/registry.rst index 84bbba8cc3..c05e38db4b 100644 --- a/docs/zh_cn/api/registry.rst +++ b/docs/zh_cn/api/registry.rst @@ -24,3 +24,4 @@ mmengine.registry build_scheduler_from_cfg count_registered_modules traverse_registry_tree + init_default_scope diff --git a/mmengine/registry/__init__.py b/mmengine/registry/__init__.py index 1e1de6bf67..de549a1d86 100644 --- a/mmengine/registry/__init__.py +++ b/mmengine/registry/__init__.py @@ -8,7 +8,8 @@ OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS, OPTIMIZERS, PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS, TRANSFORMS, VISBACKENDS, VISUALIZERS, WEIGHT_INITIALIZERS) -from .utils import count_registered_modules, traverse_registry_tree +from 
.utils import (count_registered_modules, init_default_scope, + traverse_registry_tree) __all__ = [ 'Registry', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS', 'DATASETS', @@ -18,5 +19,5 @@ 'VISBACKENDS', 'VISUALIZERS', 'LOG_PROCESSORS', 'EVALUATOR', 'DefaultScope', 'traverse_registry_tree', 'count_registered_modules', 'build_model_from_cfg', 'build_runner_from_cfg', 'build_from_cfg', - 'build_scheduler_from_cfg' + 'build_scheduler_from_cfg', 'init_default_scope' ] diff --git a/mmengine/registry/registry.py b/mmengine/registry/registry.py index 66bb4f75f7..c9d4a0aec8 100644 --- a/mmengine/registry/registry.py +++ b/mmengine/registry/registry.py @@ -32,6 +32,9 @@ class Registry: for children registry. If not specified, scope will be the name of the package where class is defined, e.g. mmdet, mmcls, mmseg. Defaults to None. + locations (list): The locations to import the modules registered + in this registry. Defaults to []. + New in version 0.4.0. Examples: >>> # define a registry @@ -54,6 +57,16 @@ class Registry: >>> pass >>> fasterrcnn = DETECTORS.build(dict(type='FasterRCNN')) + >>> # add locations to enable auto import + >>> DETECTORS = Registry('detectors', parent=MODELS, + >>> scope='det', locations=['det.models.detectors']) + >>> # define this class in 'det.models.detectors' + >>> @DETECTORS.register_module() + >>> class MaskRCNN: + >>> pass + >>> # The registry will auto import det.models.detectors.MaskRCNN + >>> fasterrcnn = DETECTORS.build(dict(type='det.MaskRCNN')) + More advanced usages can be found at https://mmengine.readthedocs.io/en/latest/tutorials/registry.html. 
""" @@ -62,11 +75,14 @@ def __init__(self, name: str, build_func: Optional[Callable] = None, parent: Optional['Registry'] = None, - scope: Optional[str] = None): + scope: Optional[str] = None, + locations: List = []): from .build_functions import build_from_cfg self._name = name self._module_dict: Dict[str, Type] = dict() self._children: Dict[str, 'Registry'] = dict() + self._locations = locations + self._imported = False if scope is not None: assert isinstance(scope, str) @@ -240,27 +256,25 @@ def switch_scope_and_registry(self, scope: str) -> Generator: # Get registry by scope if default_scope is not None: scope_name = default_scope.scope_name - if scope_name in PKG2PROJECT: - try: - module = import_module(f'{scope_name}.utils') - module.register_all_modules(False) # type: ignore - except (ImportError, AttributeError, ModuleNotFoundError): - if scope in PKG2PROJECT: - print_log( - f'{scope} is not installed and its ' - 'modules will not be registered. If you ' - 'want to use modules defined in ' - f'{scope}, Please install {scope} by ' - f'`pip install {PKG2PROJECT[scope]}.', - logger='current', - level=logging.WARNING) - else: - print_log( - f'Failed to import {scope} and register ' - 'its modules, please make sure you ' - 'have registered the module manually.', - logger='current', - level=logging.WARNING) + try: + import_module(f'{scope_name}.registry') + except (ImportError, AttributeError, ModuleNotFoundError): + if scope in PKG2PROJECT: + print_log( + f'{scope} is not installed and its ' + 'modules will not be registered. 
If you ' + 'want to use modules defined in ' + f'{scope}, Please install {scope} by ' + f'`pip install {PKG2PROJECT[scope]}.', + logger='current', + level=logging.WARNING) + else: + print_log( + f'Failed to import `{scope}.registry` ' + f'make sure the registry.py exists in `{scope}` ' + 'package.', + logger='current', + level=logging.WARNING) root = self._get_root_registry() registry = root._search_child(scope_name) if registry is None: @@ -290,6 +304,59 @@ def _get_root_registry(self) -> 'Registry': root = root.parent return root + def import_from_location(self) -> None: + """Import modules from the pre-defined locations in self._locations.""" + if not self._imported: + # Avoid circular import + from ..logging import print_log + + # avoid BC breaking + if len(self._locations) == 0 and self.scope in PKG2PROJECT: + print_log( + f'The "{self.name}" registry in {self.scope} did not ' + 'set import location. Fallback to call ' + f'`{self.scope}.utils.register_all_modules` ' + 'instead.', + logger='current', + level=logging.WARNING) + try: + module = import_module(f'{self.scope}.utils') + module.register_all_modules(False) # type: ignore + except (ImportError, AttributeError, ModuleNotFoundError): + if self.scope in PKG2PROJECT: + print_log( + f'{self.scope} is not installed and its ' + 'modules will not be registered. 
If you ' + 'want to use modules defined in ' + f'{self.scope}, Please install {self.scope} by ' + f'`pip install {PKG2PROJECT[self.scope]}.', + logger='current', + level=logging.WARNING) + else: + print_log( + f'Failed to import {self.scope} and register ' + 'its modules, please make sure you ' + 'have registered the module manually.', + logger='current', + level=logging.WARNING) + + for loc in self._locations: + try: + import_module(loc) + print_log( + f"Modules of {self.scope}'s {self.name} registry have " + f'been automatically imported from {loc}', + logger='current', + level=logging.DEBUG) + except (ImportError, AttributeError, ModuleNotFoundError): + print_log( + f'Failed to import {loc}, please check the ' + f'location of the registry {self.name} is ' + 'correct.', + logger='current', + level=logging.WARNING) + self._imported = True + def get(self, key: str) -> Optional[Type]: """Get the registry record. @@ -346,11 +413,14 @@ def get(self, key: str) -> Optional[Type]: obj_cls = None registry_name = self.name scope_name = self.scope + + # lazy import the modules to register them into the registry + self.import_from_location() + if scope is None or scope == self._scope: # get from self if real_key in self._module_dict: obj_cls = self._module_dict[real_key] - elif scope is None: # try to get the target from its parent or ancestors parent = self.parent @@ -362,24 +432,21 @@ def get(self, key: str) -> Optional[Type]: break parent = parent.parent else: + # import the registry to add the nodes into the registry tree try: - module = import_module(f'{scope}.utils') - module.register_all_modules(False) # type: ignore + import_module(f'{scope}.registry') + print_log( + f'Registry node of {scope} has been automatically ' + 'imported.', + logger='current', + level=logging.DEBUG) except (ImportError, AttributeError, ModuleNotFoundError): - if scope in PKG2PROJECT: - print_log( - f'{scope} is not installed and its modules ' - 'will not be registered. 
If you want to use ' - f'modules defined in {scope}, Please install ' - f'{scope} by `pip install {PKG2PROJECT[scope]} ', - logger='current', - level=logging.WARNING) - else: - print_log( - f'Failed to import "{scope}", and register its ' - f'modules. Please register {real_key} manually.', - logger='current', - level=logging.WARNING) + print_log( + f'Cannot auto import {scope}.registry, please check ' + f'whether the package "{scope}" is installed correctly ' + 'or import the registry manually.', + logger='current', + level=logging.DEBUG) # get from self._children if scope in self._children: obj_cls = self._children[scope].get(real_key) diff --git a/mmengine/registry/utils.py b/mmengine/registry/utils.py index 7ce11da155..19b1f8d23f 100644 --- a/mmengine/registry/utils.py +++ b/mmengine/registry/utils.py @@ -1,11 +1,13 @@ # Copyright (c) OpenMMLab. All rights reserved. import datetime import os.path as osp +import warnings from typing import Optional from mmengine.fileio import dump from mmengine.logging import print_log from . import root +from .default_scope import DefaultScope from .registry import Registry @@ -90,3 +92,25 @@ def count_registered_modules(save_path: Optional[str] = None, dump(scan_data, json_path, indent=2) print_log(f'Result has been saved to {json_path}', logger='current') return scan_data + + +def init_default_scope(scope: str) -> None: + """Initialize the given default scope. + + Args: + scope (str): The name of the default scope. 
+ """ + never_created = DefaultScope.get_current_instance( + ) is None or not DefaultScope.check_instance_created(scope) + if never_created: + DefaultScope.get_instance(scope, scope_name=scope) + return + current_scope = DefaultScope.get_current_instance() # type: ignore + if current_scope.scope_name != scope: # type: ignore + warnings.warn('The current default scope ' # type: ignore + f'"{current_scope.scope_name}" is not "{scope}", ' + '`init_default_scope` will force set the current' + f'default scope to "{scope}".') + # avoid name conflict + new_instance_name = f'{scope}-{datetime.datetime.now()}' + DefaultScope.get_instance(new_instance_name, scope_name=scope) diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index edb404c7b3..b706bfc710 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -34,8 +34,7 @@ from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS, LOG_PROCESSORS, LOOPS, MODEL_WRAPPERS, MODELS, OPTIM_WRAPPERS, PARAM_SCHEDULERS, RUNNERS, - VISUALIZERS, DefaultScope, - count_registered_modules) + VISUALIZERS, DefaultScope) from mmengine.utils import digit_version, get_git_hash, is_seq_of from mmengine.utils.dl_utils import (TORCH_VERSION, collect_env, set_multi_processing) @@ -372,11 +371,6 @@ def __init__( # Collect and log environment information. self._log_env(env_cfg) - # collect information of all modules registered in the registries - registries_info = count_registered_modules( - self.work_dir if self.rank == 0 else None, verbose=False) - self.logger.debug(registries_info) - # Build `message_hub` for communication among components. # `message_hub` can store log scalars (loss, learning rate) and # runtime information (iter and epoch). 
Those components that do not diff --git a/tests/test_registry/test_registry_utils.py b/tests/test_registry/test_registry_utils.py index 35670435d7..457b033c9b 100644 --- a/tests/test_registry/test_registry_utils.py +++ b/tests/test_registry/test_registry_utils.py @@ -1,10 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. +import datetime import os.path as osp from tempfile import TemporaryDirectory from unittest import TestCase, skipIf -from mmengine.registry import (Registry, count_registered_modules, root, - traverse_registry_tree) +from mmengine.registry import (DefaultScope, Registry, + count_registered_modules, init_default_scope, + root, traverse_registry_tree) from mmengine.utils import is_installed @@ -62,3 +64,17 @@ def test_count_all_registered_modules(self): self.assertFalse( osp.exists( osp.join(temp_dir.name, 'modules_statistic_results.json'))) + + @skipIf(not is_installed('torch'), 'tests requires torch') + def test_init_default_scope(self): + # init default scope + init_default_scope('mmdet') + self.assertEqual(DefaultScope.get_current_instance().scope_name, + 'mmdet') + + # init default scope when another scope is init + name = f'test-{datetime.datetime.now()}' + DefaultScope.get_instance(name, scope_name='test') + with self.assertWarnsRegex( + Warning, 'The current default scope "test" is not "mmdet"'): + init_default_scope('mmdet')