Skip to content

Commit

Permalink
[Enhance] Support config torch environment in config
Browse files Browse the repository at this point in the history
  • Loading branch information
HAOCHENYE committed Nov 7, 2022
1 parent 15fe131 commit 3e3fecb
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 40 deletions.
7 changes: 1 addition & 6 deletions mmengine/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,13 +191,8 @@ def revert_sync_batchnorm(module: nn.Module) -> nn.Module:
# qconfig exists in quantized models
if hasattr(module, 'qconfig'):
module_output.qconfig = module.qconfig

for name, child in module.named_children():
try:
module_output.add_module(name, revert_sync_batchnorm(child))
except Exception:
raise RuntimeError('Cannot revert ')

module_output.add_module(name, revert_sync_batchnorm(child))
del module
return module_output

Expand Down
16 changes: 9 additions & 7 deletions mmengine/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,15 +691,18 @@ def set_randomness(self,
diff_rank_seed=diff_rank_seed)

def set_torch_cfg(self, torch_cfg: dict) -> None:
"""_summary_
"""Set torch variable to by config.
Args:
torch_cfg (dict): _description_
torch_cfg (dict): Contains key-value pair which defines the names
of torch variables and corresponding values.
Raises:
ValueError: _description_
AttributeError: _description_
"""
Examples:
>>> torch_cfg = dict(
>>> backends.cuda.matmul.allow_tf32=True,
>>> torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction=True
>>> )
""" # noqa: E501
for attributes, value in torch_cfg.items():
assert isinstance(attributes, str)
if attributes in [
Expand All @@ -712,7 +715,6 @@ def set_torch_cfg(self, torch_cfg: dict) -> None:
attributes_list = attributes.split('.')
modules = attributes_list[:-1]
attribute = attributes_list[-1]

target_module = torch
try:
for module in modules:
Expand Down
55 changes: 28 additions & 27 deletions tests/test_runner/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,8 +329,32 @@ def __init__(self,
cfg=None):
pass

def setup_env(self, env_cfg):
pass
def test_setup_env(self):
# 1.Test torch_cfg.
# 1.1 Test valid torch_cfg
cfg = copy.deepcopy(self.epoch_based_cfg)
torch_cfg = {'backends.cuda.matmul.allow_tf32': True}
cfg.env_cfg.torch_cfg = torch_cfg
cfg.experiment_name = 'test_build_logger1'
runner = Runner.from_cfg(cfg)
self.assertTrue(torch.backends.cuda.matmul.allow_tf32, True)

torch_cfg['backends.cuda.matmul.allow_tf32'] = False
runner.setup_env(cfg.env_cfg)
self.assertFalse(torch.backends.cuda.matmul.allow_tf32, True)

# Test invalid torch_cfg
with self.assertRaisesRegex(ValueError,
'torch.backends.cuda.matmul.allow_tf31'):
cfg.env_cfg.torch_cfg = {
'torch.backends.cuda.matmul.allow_tf31': True
}
runner.setup_env(cfg.env_cfg)

# Test set torch_cfg with randomness_cfg
with self.assertRaisesRegex(ValueError, 'backends.cudnn.benchmark'):
cfg.env_cfg.torch_cfg = {'backends.cudnn.benchmark': True}
runner.setup_env(cfg.env_cfg)


@EVALUATOR.register_module()
Expand Down Expand Up @@ -705,31 +729,8 @@ def test_from_cfg(self):
self.assertIsInstance(runner, Runner)

def test_setup_env(self):
# 1.Test torch_cfg.
# 1.1 Test valid torch_cfg
cfg = copy.deepcopy(self.epoch_based_cfg)
torch_cfg = {'backends.cuda.matmul.allow_tf32': True}
cfg.env_cfg.torch_cfg = torch_cfg
cfg.experiment_name = 'test_build_logger1'
runner = Runner.from_cfg(cfg)
self.assertTrue(torch.backends.cuda.matmul.allow_tf32, True)

torch_cfg['backends.cuda.matmul.allow_tf32'] = False
runner.setup_env(cfg.env_cfg)
self.assertFalse(torch.backends.cuda.matmul.allow_tf32, True)

# Test invalid torch_cfg
with self.assertRaisesRegex(ValueError,
'torch.backends.cuda.matmul.allow_tf31'):
cfg.env_cfg.torch_cfg = {
'torch.backends.cuda.matmul.allow_tf31': True
}
runner.setup_env(cfg.env_cfg)

# Test set torch_cfg with randomness_cfg
with self.assertRaisesRegex(ValueError, 'backends.cudnn.benchmark'):
cfg.env_cfg.torch_cfg = {'backends.cudnn.benchmark': True}
runner.setup_env(cfg.env_cfg)
# TODO
pass

def test_build_logger(self):
self.epoch_based_cfg.experiment_name = 'test_build_logger1'
Expand Down

0 comments on commit 3e3fecb

Please sign in to comment.