Skip to content

Commit

Permalink
Move is_method_overriden to mmcv/utils/misc.py
Browse files Browse the repository at this point in the history
  • Loading branch information
mzr1996 committed Jun 25, 2021
1 parent 5087679 commit 7008c04
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 50 deletions.
3 changes: 1 addition & 2 deletions mmcv/runner/hooks/hook.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copyright (c) Open-MMLab. All rights reserved.
from mmcv.runner.utils import is_method_overriden
from mmcv.utils import Registry
from mmcv.utils import Registry, is_method_overriden

HOOKS = Registry('hook')

Expand Down
16 changes: 0 additions & 16 deletions mmcv/runner/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,19 +79,3 @@ def set_random_seed(seed, deterministic=False, use_rank_shift=False):
if deterministic:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


def is_method_overriden(method, base_class, sub_class):
"""Check if a method of base class is overriden in sub class.
Args:
method (str): the method name to check.
base_class (type): the class of the base class.
sub_class (type | Any): the class or instance of the sub class.
"""
if not isinstance(sub_class, type):
sub_class = sub_class.__class__

base_method = getattr(base_class, method)
sub_method = getattr(sub_class, method)
return sub_method != base_method
14 changes: 8 additions & 6 deletions mmcv/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
# Copyright (c) Open-MMLab. All rights reserved.
from .config import Config, ConfigDict, DictAction
from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
import_modules_from_strings, is_list_of, is_seq_of, is_str,
is_tuple_of, iter_cast, list_cast, requires_executable,
requires_package, slice_list, tuple_cast)
import_modules_from_strings, is_list_of,
is_method_overriden, is_seq_of, is_str, is_tuple_of,
iter_cast, list_cast, requires_executable, requires_package,
slice_list, tuple_cast)
from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist,
scandir, symlink)
from .progressbar import (ProgressBar, track_iter_progress,
Expand All @@ -29,17 +30,18 @@
'Timer', 'TimerError', 'check_time', 'deprecated_api_warning',
'digit_version', 'get_git_hash', 'import_modules_from_strings',
'assert_dict_contains_subset', 'assert_attrs_equal',
'assert_dict_has_keys', 'assert_keys_equal', 'check_python_script'
'assert_dict_has_keys', 'assert_keys_equal', 'check_python_script',
'is_method_overriden'
]
else:
from .env import collect_env
from .logging import get_logger, print_log
from .parrots_jit import jit, skip_no_elena
from .parrots_wrapper import (
CUDA_HOME, TORCH_VERSION, BuildExtension, CppExtension, CUDAExtension,
DataLoader, PoolDataLoader, SyncBatchNorm, _AdaptiveAvgPoolNd,
_AdaptiveMaxPoolNd, _AvgPoolNd, _BatchNorm, _ConvNd,
_ConvTransposeMixin, _InstanceNorm, _MaxPoolNd, get_build_config)
from .parrots_jit import jit, skip_no_elena
from .registry import Registry, build_from_cfg
__all__ = [
'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger',
Expand All @@ -58,5 +60,5 @@
'get_git_hash', 'import_modules_from_strings', 'jit', 'skip_no_elena',
'assert_dict_contains_subset', 'assert_attrs_equal',
'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer',
'assert_params_all_zeros', 'check_python_script'
'assert_params_all_zeros', 'check_python_script', 'is_method_overriden'
]
16 changes: 16 additions & 0 deletions mmcv/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,3 +313,19 @@ def new_func(*args, **kwargs):
return new_func

return api_warning_wrapper


def is_method_overriden(method, base_class, sub_class):
"""Check if a method of base class is overriden in sub class.
Args:
method (str): the method name to check.
base_class (type): the class of the base class.
sub_class (type | Any): the class or instance of the sub class.
"""
if not isinstance(sub_class, type):
sub_class = sub_class.__class__

base_method = getattr(base_class, method)
sub_method = getattr(sub_class, method)
return sub_method != base_method
27 changes: 1 addition & 26 deletions tests/test_runner/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import torch

from mmcv.runner import is_method_overriden, set_random_seed
from mmcv.runner import set_random_seed


def test_set_random_seed():
Expand All @@ -26,28 +26,3 @@ def test_set_random_seed():
assert a_random == b_random
assert np.equal(a_np_random, b_np_random).all()
assert torch.equal(a_torch_random, b_torch_random)


def test_is_method_overriden():

class Base(object):

def foo1():
pass

def foo2():
pass

class Sub(Base):

def foo1():
pass

# test passing sub class directly
assert is_method_overriden('foo1', Base, Sub)
assert not is_method_overriden('foo2', Base, Sub)

# test passing instance of sub class
sub_instance = Sub()
assert is_method_overriden('foo1', Base, sub_instance)
assert not is_method_overriden('foo2', Base, sub_instance)
25 changes: 25 additions & 0 deletions tests/test_utils/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,28 @@ def test_import_modules_from_strings():
['os.path', '_not_implemented'], allow_failed_imports=True)
assert imported[0] == osp
assert imported[1] is None


def test_is_method_overriden():

class Base(object):

def foo1():
pass

def foo2():
pass

class Sub(Base):

def foo1():
pass

# test passing sub class directly
assert mmcv.is_method_overriden('foo1', Base, Sub)
assert not mmcv.is_method_overriden('foo2', Base, Sub)

# test passing instance of sub class
sub_instance = Sub()
assert mmcv.is_method_overriden('foo1', Base, sub_instance)
assert not mmcv.is_method_overriden('foo2', Base, sub_instance)

0 comments on commit 7008c04

Please sign in to comment.