From 7008c04a7a8a6437c7e150007b554228f2e726fc Mon Sep 17 00:00:00 2001 From: mzr1996 Date: Fri, 25 Jun 2021 14:19:31 +0800 Subject: [PATCH] Move `is_method_overriden` to `mmcv/utils/misc.py` --- mmcv/runner/hooks/hook.py | 3 +-- mmcv/runner/utils.py | 16 ---------------- mmcv/utils/__init__.py | 14 ++++++++------ mmcv/utils/misc.py | 16 ++++++++++++++++ tests/test_runner/test_utils.py | 27 +-------------------------- tests/test_utils/test_misc.py | 25 +++++++++++++++++++++++++ 6 files changed, 51 insertions(+), 50 deletions(-) diff --git a/mmcv/runner/hooks/hook.py b/mmcv/runner/hooks/hook.py index f9a138a85c7..46b18fefa90 100644 --- a/mmcv/runner/hooks/hook.py +++ b/mmcv/runner/hooks/hook.py @@ -1,6 +1,5 @@ # Copyright (c) Open-MMLab. All rights reserved. -from mmcv.runner.utils import is_method_overriden -from mmcv.utils import Registry +from mmcv.utils import Registry, is_method_overriden HOOKS = Registry('hook') diff --git a/mmcv/runner/utils.py b/mmcv/runner/utils.py index ac55d9ad73c..168305f0cd3 100644 --- a/mmcv/runner/utils.py +++ b/mmcv/runner/utils.py @@ -79,19 +79,3 @@ def set_random_seed(seed, deterministic=False, use_rank_shift=False): if deterministic: torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False - - -def is_method_overriden(method, base_class, sub_class): - """Check if a method of base class is overriden in sub class. - - Args: - method (str): the method name to check. - base_class (type): the class of the base class. - sub_class (type | Any): the class or instance of the sub class. - """ - if not isinstance(sub_class, type): - sub_class = sub_class.__class__ - - base_method = getattr(base_class, method) - sub_method = getattr(sub_class, method) - return sub_method != base_method diff --git a/mmcv/utils/__init__.py b/mmcv/utils/__init__.py index ba2a2c9e94b..2ec8ebfb7f6 100644 --- a/mmcv/utils/__init__.py +++ b/mmcv/utils/__init__.py @@ -2,9 +2,10 @@ # Copyright (c) Open-MMLab. All rights reserved. from .config import Config, ConfigDict, DictAction from .misc import (check_prerequisites, concat_list, deprecated_api_warning, - import_modules_from_strings, is_list_of, is_seq_of, is_str, - is_tuple_of, iter_cast, list_cast, requires_executable, - requires_package, slice_list, tuple_cast) + import_modules_from_strings, is_list_of, + is_method_overriden, is_seq_of, is_str, is_tuple_of, + iter_cast, list_cast, requires_executable, requires_package, + slice_list, tuple_cast) from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist, scandir, symlink) from .progressbar import (ProgressBar, track_iter_progress, @@ -29,17 +30,18 @@ 'Timer', 'TimerError', 'check_time', 'deprecated_api_warning', 'digit_version', 'get_git_hash', 'import_modules_from_strings', 'assert_dict_contains_subset', 'assert_attrs_equal', - 'assert_dict_has_keys', 'assert_keys_equal', 'check_python_script' + 'assert_dict_has_keys', 'assert_keys_equal', 'check_python_script', + 'is_method_overriden' ] else: from .env import collect_env from .logging import get_logger, print_log + from .parrots_jit import jit, skip_no_elena from .parrots_wrapper import ( CUDA_HOME, TORCH_VERSION, BuildExtension, CppExtension, CUDAExtension, DataLoader, PoolDataLoader, SyncBatchNorm, _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _BatchNorm, _ConvNd, _ConvTransposeMixin, _InstanceNorm, _MaxPoolNd, get_build_config) - from .parrots_jit import jit, skip_no_elena from .registry import Registry, build_from_cfg __all__ = [ 'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger', @@ -58,5 +60,5 @@ 'get_git_hash', 'import_modules_from_strings', 'jit', 'skip_no_elena', 'assert_dict_contains_subset', 'assert_attrs_equal', 'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer', - 'assert_params_all_zeros', 'check_python_script' + 'assert_params_all_zeros', 'check_python_script', 'is_method_overriden' ] diff --git a/mmcv/utils/misc.py b/mmcv/utils/misc.py index 5e4645e37d0..01e89ec29f4 100644 --- a/mmcv/utils/misc.py +++ b/mmcv/utils/misc.py @@ -313,3 +313,19 @@ def new_func(*args, **kwargs): return new_func return api_warning_wrapper + + +def is_method_overriden(method, base_class, sub_class): + """Check if a method of base class is overriden in sub class. + + Args: + method (str): the method name to check. + base_class (type): the class of the base class. + sub_class (type | Any): the class or instance of the sub class. + """ + if not isinstance(sub_class, type): + sub_class = sub_class.__class__ + + base_method = getattr(base_class, method) + sub_method = getattr(sub_class, method) + return sub_method != base_method diff --git a/tests/test_runner/test_utils.py b/tests/test_runner/test_utils.py index e2d8ec114a4..3983e80cd7e 100644 --- a/tests/test_runner/test_utils.py +++ b/tests/test_runner/test_utils.py @@ -4,7 +4,7 @@ import numpy as np import torch -from mmcv.runner import is_method_overriden, set_random_seed +from mmcv.runner import set_random_seed def test_set_random_seed(): @@ -26,28 +26,3 @@ def test_set_random_seed(): assert a_random == b_random assert np.equal(a_np_random, b_np_random).all() assert torch.equal(a_torch_random, b_torch_random) - - -def test_is_method_overriden(): - - class Base(object): - - def foo1(): - pass - - def foo2(): - pass - - class Sub(Base): - - def foo1(): - pass - - # test passing sub class directly - assert is_method_overriden('foo1', Base, Sub) - assert not is_method_overriden('foo2', Base, Sub) - - # test passing instance of sub class - sub_instance = Sub() - assert is_method_overriden('foo1', Base, sub_instance) - assert not is_method_overriden('foo2', Base, sub_instance) diff --git a/tests/test_utils/test_misc.py b/tests/test_utils/test_misc.py index adcd26ea0da..af179ea9faf 100644 --- a/tests/test_utils/test_misc.py +++ b/tests/test_utils/test_misc.py @@ -134,3 +134,28 @@ def test_import_modules_from_strings(): ['os.path', '_not_implemented'], allow_failed_imports=True) assert imported[0] == osp assert imported[1] is None + + +def test_is_method_overriden(): + + class Base(object): + + def foo1(): + pass + + def foo2(): + pass + + class Sub(Base): + + def foo1(): + pass + + # test passing sub class directly + assert mmcv.is_method_overriden('foo1', Base, Sub) + assert not mmcv.is_method_overriden('foo2', Base, Sub) + + # test passing instance of sub class + sub_instance = Sub() + assert mmcv.is_method_overriden('foo1', Base, sub_instance) + assert not mmcv.is_method_overriden('foo2', Base, sub_instance)