diff --git a/mmengine/config/config.py b/mmengine/config/config.py index ae5f275486..5a9f706e4b 100644 --- a/mmengine/config/config.py +++ b/mmengine/config/config.py @@ -404,9 +404,7 @@ def _file2dict(filename: str, if len(duplicate_keys) > 0: raise KeyError('Duplicate key is not allowed among bases. ' f'Duplicate keys: {duplicate_keys}') - _cfg_dict = Config._dict_to_config_dict(_cfg_dict) - if scope is not None: - Config._parse_scope(_cfg_dict, scope) + _cfg_dict = Config._dict_to_config_dict(_cfg_dict, scope) base_cfg_dict.update(_cfg_dict) if filename.endswith('.py'): @@ -423,6 +421,7 @@ def _file2dict(filename: str, if isinstance(value, (types.FunctionType, types.ModuleType)): cfg_dict.pop(key) temp_config_file.close() + Config._parse_scope(cfg_dict) # check deprecation information if DEPRECATION_KEY in cfg_dict: @@ -460,43 +459,54 @@ def _file2dict(filename: str, return cfg_dict, cfg_text @staticmethod - def _dict_to_config_dict(cfg: dict): + def _dict_to_config_dict(cfg: dict, + scope: Optional[str] = None, + has_scope=True): """Recursively converts ``dict`` to :obj:`ConfigDict`. Args: cfg (dict): Config dict. + scope (str, optional): Scope of instance. + has_scope (bool): Whether to add `_scope_` key to config dict. Returns: ConfigDict: Converted dict. """ + # Only the outer dict with key `type` should have the key `_scope_`. if isinstance(cfg, dict): + if has_scope and 'type' in cfg: + has_scope = False + cfg._scope_ = scope # type: ignore cfg = ConfigDict(cfg) + dict.__setattr__(cfg, 'scope', scope) for key, value in cfg.items(): - cfg[key] = Config._dict_to_config_dict(value) + cfg[key] = Config._dict_to_config_dict( + value, scope=scope, has_scope=has_scope) elif isinstance(cfg, tuple): - cfg = tuple(Config._dict_to_config_dict(_cfg) for _cfg in cfg) + cfg = tuple( + Config._dict_to_config_dict(_cfg, scope, has_scope=has_scope) + for _cfg in cfg) elif isinstance(cfg, list): - cfg = [Config._dict_to_config_dict(_cfg) for _cfg in cfg] + cfg = [ + Config._dict_to_config_dict(_cfg, scope, has_scope=has_scope) + for _cfg in cfg + ] return cfg @staticmethod - def _parse_scope(cfg: dict, scope: str) -> None: - """Recursively add scope to config dict containing ``type`` field. + def _parse_scope(cfg: dict) -> None: + """Adds ``_scope_`` to :obj:`ConfigDict` instance. If the config dict already has the scope, scope will not be overwritten. Args: cfg (dict): Config needs to be parsed with scope. - scope (str): scope of external package. """ - if isinstance(cfg, dict): - if 'type' in cfg and '_scope_' not in cfg: - cfg['_scope_'] = scope - for value in cfg.values(): - Config._parse_scope(value, scope) + if isinstance(cfg, ConfigDict): + cfg._scope_ = cfg.scope elif isinstance(cfg, (tuple, list)): - [Config._parse_scope(value, scope) for value in cfg] + [Config._parse_scope(value) for value in cfg] else: return diff --git a/mmengine/registry/registry.py b/mmengine/registry/registry.py index 66e474e958..2d73e2df75 100644 --- a/mmengine/registry/registry.py +++ b/mmengine/registry/registry.py @@ -1,10 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. import inspect +import logging import sys -import warnings from collections.abc import Callable +from contextlib import contextmanager from importlib import import_module -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, Generator, List, Optional, Tuple, Type, Union from ..config import Config, ConfigDict from ..utils import ManagerMixin, is_installed, is_seq_of @@ -33,7 +34,7 @@ def build_runner_from_cfg(cfg: Union[dict, ConfigDict, Config], Returns: object: The constructed runner object. """ - from ..logging.logger import MMLogger + from ..logging import print_log assert isinstance( cfg, @@ -44,38 +45,43 @@ def build_runner_from_cfg(cfg: Union[dict, ConfigDict, Config], f'but got {type(registry)}') args = cfg.copy() - obj_type = args.pop('runner_type', 'mmengine.Runner') - if isinstance(obj_type, str): - runner_cls = registry.get(obj_type) - if runner_cls is None: - raise KeyError( - f'{obj_type} is not in the {registry.name} registry. ' - f'Please check whether the value of `{obj_type}` is correct or' - ' it was registered as expected. More details can be found at' - ' https://mmengine.readthedocs.io/en/latest/tutorials/config.html#import-custom-python-modules' # noqa: E501 - ) - elif inspect.isclass(obj_type): - runner_cls = obj_type - else: - raise TypeError( - f'type must be a str or valid type, but got {type(obj_type)}') - - try: - runner = runner_cls.from_cfg(args) # type: ignore - logger: MMLogger = MMLogger.get_current_instance() - logger.info( - f'An `{runner_cls.__name__}` instance is built ' # type: ignore - f'from registry, its implementation can be found in' - f'{runner_cls.__module__}') # type: ignore - return runner - - except Exception as e: - # Normal TypeError does not print class name. - cls_location = '/'.join( - runner_cls.__module__.split('.')) # type: ignore - raise type(e)( - f'class `{runner_cls.__name__}` in ' # type: ignore - f'{cls_location}.py: {e}') + # Runner should be built under target scope, if `_scope_` is defined + # in cfg, current default scope should switch to specified scope + # temporarily. + scope = args.pop('_scope_', None) + with registry.get_registry_by_scope(scope) as registry: + obj_type = args.pop('runner_type', 'mmengine.Runner') + if isinstance(obj_type, str): + runner_cls = registry.get(obj_type) + if runner_cls is None: + raise KeyError( + f'{obj_type} is not in the {registry.name} registry. ' + f'Please check whether the value of `{obj_type}` is ' + 'correct or it was registered as expected. More details ' + 'can be found at https://mmengine.readthedocs.io/en/latest/tutorials/config.html#import-custom-python-modules' # noqa: E501 + ) + elif inspect.isclass(obj_type): + runner_cls = obj_type + else: + raise TypeError( + f'type must be a str or valid type, but got {type(obj_type)}') + + try: + runner = runner_cls.from_cfg(args) # type: ignore + print_log( + f'An `{runner_cls.__name__}` instance is built from ' # type: ignore # noqa: E501 + 'registry, its implementation can be found in' + f'{runner_cls.__module__}', # type: ignore + logger='current') + return runner + + except Exception as e: + # Normal TypeError does not print class name. + cls_location = '/'.join( + runner_cls.__module__.split('.')) # type: ignore + raise type(e)( + f'class `{runner_cls.__name__}` in ' # type: ignore + f'{cls_location}.py: {e}') def build_model_from_cfg(cfg, registry, default_args=None): @@ -83,20 +89,20 @@ def build_model_from_cfg(cfg, registry, default_args=None): ``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built. Args: - cfg (dict, list[dict]): The config of modules, is is either a config - dict or a list of config dicts. If cfg is a list, a - the built modules will be wrapped with ``nn.Sequential``. + cfg (dict, list[dict]): The config of modules, which is either a config + dict or a list of config dicts. If cfg is a list, the built + modules will be wrapped with ``nn.Sequential``. registry (:obj:`Registry`): A registry the module belongs to. default_args (dict, optional): Default arguments to build the module. Defaults to None. Returns: - nn.Module: A built nn module. + nn.Module: A built nn.Module. """ from ..model import Sequential if isinstance(cfg, list): modules = [ - build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg + build_from_cfg(_cfg, registry, default_args) for _cfg in cfg ] return Sequential(*modules) else: @@ -150,7 +156,7 @@ def build_from_cfg( object: The constructed object. """ # Avoid circular import - from ..logging.logger import MMLogger + from ..logging import print_log if not isinstance(cfg, (dict, ConfigDict, Config)): raise TypeError( @@ -180,49 +186,8 @@ def build_from_cfg( # Instance should be built under target scope, if `_scope_` is defined # in cfg, current default scope should switch to specified scope # temporarily. - scope = args.get('_scope_', None) - with DefaultScope.overwrite_default_scope(scope): - # get the global default scope - default_scope = DefaultScope.get_current_instance() - if default_scope is not None: - scope_name = default_scope.scope_name - # Check installed external repo. - from mmengine.config.utils import PKG2PROJECT - if scope_name in PKG2PROJECT: - is_installed(PKG2PROJECT[scope_name]) - # TODO replace with import from. - module = import_module(f'{scope_name}.utils') - module.register_all_modules() # type: ignore # noqa: E501 - - root = registry.get_root_registry() - _registry = root.search_child(scope_name) - if _registry is None: - # if `default_scope` can not be found, fallback to use self - warnings.warn( - f'Failed to search registry with scope "{scope_name}" ' - f'in the "{root.name}" registry tree. ' - f'As a workaround, the current "{registry.name}" registry ' - f'in "{registry.scope}" is used to build instance. This ' - f'may cause unexpected failure when running the built ' - f'modules. Please check whether "{scope_name}" is a ' - f'correct scope, or whether the registry is ' - f'initialized.') - else: - registry = _registry - - # remove _scope_ defined in cfg. - def _remove_scope(cfg_dict, scope): - if isinstance(cfg_dict, dict): - _scope_ = cfg_dict.get('_scope_', None) - if _scope_ == scope: - cfg_dict.pop('_scope_') - [_remove_scope(_value, scope) for _value in cfg_dict.values()] - elif isinstance(cfg_dict, (list, tuple)): - [_remove_scope(_value, scope) for _value in cfg_dict] - - if '_scope_' in args: - _remove_scope(args, scope) - + scope = args.pop('_scope_', None) + with registry.get_registry_by_scope(scope) as registry: obj_type = args.pop('type') if isinstance(obj_type, str): obj_cls = registry.get(obj_type) @@ -249,11 +214,11 @@ def _remove_scope(cfg_dict, scope): else: obj = obj_cls(**args) # type: ignore - logger: MMLogger = MMLogger.get_current_instance() - logger.info( + print_log( f'An `{obj_cls.__name__}` instance is built from ' # type: ignore # noqa: E501 - f'registry, its implementation can be found in ' - f'{obj_cls.__module__}') # type: ignore + 'registry, its implementation can be found in ' + f'{obj_cls.__module__}', # type: ignore + logger='current') return obj except Exception as e: @@ -428,9 +393,58 @@ def children(self): @property def root(self): - return self.get_root_registry() + return self._get_root_registry() + + @contextmanager + def get_registry_by_scope(self, scope: str) -> Generator: + """Get the corresponding registry of the target scope. + + If the registry of the corresponding scope exists, return the + registry, otherwise return the current itself. + + Args: + scope (str): The target scope. + """ + # Switch to the given scope temporarily. If the corresponding registry + # can be found in root registry, return the registry under the scope, + # otherwise return the registry itself. + from mmengine.config.utils import PKG2PROJECT + from ..logging import print_log + + with DefaultScope.overwrite_default_scope(scope): + # Get the global default scope + default_scope = DefaultScope.get_current_instance() + # Get registry by scope + if default_scope is not None: + scope_name = default_scope.scope_name + if scope_name in PKG2PROJECT: + is_installed(PKG2PROJECT[scope_name]) + # TODO replace with import from. + module = import_module(f'{scope_name}.utils') + module.register_all_modules() # type: ignore + root = self._get_root_registry() + registry = root._search_child(scope_name) + if registry is None: + # if `default_scope` can not be found, fallback to argument + # `registry` + print_log( + f'Failed to search registry with scope "{scope_name}" ' + f'in the "{root.name}" registry tree. ' + f'As a workaround, the current "{self.name}" registry ' + f'in "{self.scope}" is used to build instance. This ' + 'may cause unexpected failure when running the built ' + f'modules. Please check whether "{scope_name}" is a ' + 'correct scope, or whether the registry is ' + 'initialized.', + logger='current', + level=logging.WARNING) + registry = self + # If there is no built default scope, just return current registry. + else: + registry = self + yield registry - def get_root_registry(self) -> 'Registry': + def _get_root_registry(self) -> 'Registry': """Return the root registry.""" root = self while root.parent is not None: @@ -515,7 +529,7 @@ def get(self, key: str) -> Optional[Type]: registry_name = self._children[scope].name scope_name = scope else: - root = self.get_root_registry() + root = self._get_root_registry() if scope != root._scope and scope not in root._children: # If not skip directly, `root.get(key)` will recursively @@ -531,7 +545,7 @@ def get(self, key: str) -> Optional[Type]: f' registry in "{scope_name}"') return obj_cls - def search_child(self, scope: str) -> Optional['Registry']: + def _search_child(self, scope: str) -> Optional['Registry']: """Depth-first search for the corresponding registry in its children. Note that the method only search for the corresponding registry from @@ -551,7 +565,7 @@ def search_child(self, scope: str) -> Optional['Registry']: return self for child in self._children.values(): - registry = child.search_child(scope) + registry = child._search_child(scope) if registry is not None: return registry diff --git a/tests/data/config/py_config/test_get_external_cfg3.py b/tests/data/config/py_config/test_get_external_cfg3.py index a74fd5d498..4161127e75 100644 --- a/tests/data/config/py_config/test_get_external_cfg3.py +++ b/tests/data/config/py_config/test_get_external_cfg3.py @@ -7,3 +7,11 @@ ] custom_hooks = [dict(type='mmdet.DetVisualizationHook')] + +model = dict( + roi_head=dict( + bbox_head=dict( + loss_cls=dict(_delete_=True, type='test.ToyLoss') + ) + ) +) \ No newline at end of file diff --git a/tests/test_config/test_config.py b/tests/test_config/test_config.py index 89b35d5ace..f04072d486 100644 --- a/tests/test_config/test_config.py +++ b/tests/test_config/test_config.py @@ -13,7 +13,7 @@ from mmengine import Config, ConfigDict, DictAction from mmengine.fileio import dump, load from mmengine.registry import MODELS, DefaultScope, Registry -from mmengine.utils import get_installed_path, is_installed +from mmengine.utils import is_installed class TestConfig: @@ -724,14 +724,14 @@ def test_copy(self): def test_get_external_cfg(self): ext_cfg_path = osp.join(self.data_path, 'config/py_config/test_get_external_cfg.py') - package_path = get_installed_path('mmdet') - cfg_path = osp.join( - package_path, '.mim', - 'configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py') ext_cfg = Config.fromfile(ext_cfg_path) - cfg = Config.fromfile(cfg_path) - Config._parse_scope(cfg, 'mmdet') - assert cfg._cfg_dict == ext_cfg._cfg_dict + assert ext_cfg._cfg_dict.model.neck == dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=5, + ) + assert '_scope_' in ext_cfg._cfg_dict.model @pytest.mark.skipif( not is_installed('mmdet'), reason='mmdet should be installed') diff --git a/tests/test_registry/test_registry.py b/tests/test_registry/test_registry.py index 08a20b3c20..56d67e3831 100644 --- a/tests/test_registry/test_registry.py +++ b/tests/test_registry/test_registry.py @@ -196,7 +196,7 @@ def _build_registry(self): return registries - def test_get_root_registry(self): + def test__get_root_registry(self): # Hierarchical Registry # DOGS # _______|_______ @@ -209,10 +209,10 @@ def test_get_root_registry(self): registries = self._build_registry() DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS = registries[:4] - assert DOGS.get_root_registry() is DOGS - assert HOUNDS.get_root_registry() is DOGS - assert LITTLE_HOUNDS.get_root_registry() is DOGS - assert MID_HOUNDS.get_root_registry() is DOGS + assert DOGS._get_root_registry() is DOGS + assert HOUNDS._get_root_registry() is DOGS + assert LITTLE_HOUNDS._get_root_registry() is DOGS + assert MID_HOUNDS._get_root_registry() is DOGS def test_get(self): # Hierarchical Registry @@ -307,7 +307,7 @@ class LittlePedigreeSamoyed: assert DOGS.get('samoyed.LittlePedigreeSamoyed') is None assert LITTLE_HOUNDS.get('mid_hound.PedigreeSamoyedddddd') is None - def test_search_child(self): + def test__search_child(self): # Hierarchical Registry # DOGS # _______|_______ @@ -320,11 +320,11 @@ def test_search_child(self): registries = self._build_registry() DOGS, HOUNDS, LITTLE_HOUNDS = registries[:3] - assert DOGS.search_child('hound') is HOUNDS - assert DOGS.search_child('not a child') is None - assert DOGS.search_child('little_hound') is LITTLE_HOUNDS - assert LITTLE_HOUNDS.search_child('hound') is None - assert LITTLE_HOUNDS.search_child('mid_hound') is None + assert DOGS._search_child('hound') is HOUNDS + assert DOGS._search_child('not a child') is None + assert DOGS._search_child('little_hound') is LITTLE_HOUNDS + assert LITTLE_HOUNDS._search_child('hound') is None + assert LITTLE_HOUNDS._search_child('mid_hound') is None @pytest.mark.parametrize('cfg_type', [dict, ConfigDict, Config]) def test_build(self, cfg_type):