[Enhance] Ensure metrics is not empty when saving best ckpts #849

Merged · 3 commits · Dec 28, 2022 (diff shows changes from 2 commits)
mmengine/hooks/checkpoint_hook.py — 12 changes: 6 additions & 6 deletions

@@ -294,19 +294,19 @@ def after_val_epoch(self, runner, metrics):
             runner (Runner): The runner of the training process.
             metrics (dict): Evaluation results of all metrics
         """
+        if len(metrics) == 0:
+            runner.logger.warning(
+                'Since `metrics` is an empty dict, the behavior to save '
+                'the best checkpoint will be skipped in this evaluation.')
+            return
+
         self._save_best_checkpoint(runner, metrics)
 
     def _get_metric_score(self, metrics, key_indicator):
         eval_res = OrderedDict()
         if metrics is not None:
             eval_res.update(metrics)
 
-        if len(eval_res) == 0:
-            warnings.warn(
-                'Since `eval_res` is an empty dict, the behavior to save '
-                'the best checkpoint will be skipped in this evaluation.')
-            return None
-
         return eval_res[key_indicator]
 
     def _save_checkpoint(self, runner) -> None:
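In effect, the empty-metrics guard moves out of `_get_metric_score` (a `warnings.warn` plus `return None`) and up into `after_val_epoch`, which now logs through `runner.logger` and returns before any best-checkpoint bookkeeping runs. A minimal sketch of the new behavior, using a mocked runner in the same spirit as the updated test (the mock setup is illustrative, not part of the change itself):

    from unittest.mock import Mock

    from mmengine.hooks import CheckpointHook

    runner = Mock()                 # stand-in for a real mmengine Runner
    runner.logger.warning = Mock()  # capture the warning call

    hook = CheckpointHook(interval=2, by_epoch=True, save_best='auto')

    # An empty metrics dict now short-circuits: one warning, no checkpoint I/O.
    hook.after_val_epoch(runner, {})
    runner.logger.warning.assert_called_once()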
tests/test_hooks/test_checkpoint_hook.py — 142 changes: 68 additions & 74 deletions

@@ -148,6 +148,7 @@ def test_after_val_epoch(self, tmp_path):
         runner.work_dir = tmp_path
         runner.epoch = 9
         runner.model = Mock()
+        runner.logger.warning = Mock()
         runner.message_hub = MessageHub.get_instance('test_after_val_epoch')
 
         with pytest.raises(ValueError):
@@ -159,22 +160,11 @@
             CheckpointHook(
                 interval=2, by_epoch=True, save_best='auto', rule='unsupport')
 
-        # if eval_res is an empty dict, print a warning information
-        with pytest.warns(UserWarning) as record_warnings:
-            eval_hook = CheckpointHook(
-                interval=2, by_epoch=True, save_best='auto')
-            eval_hook._get_metric_score(None, None)
-        # Since there will be many warnings thrown, we just need to check
-        # if the expected exceptions are thrown
-        expected_message = (
-            'Since `eval_res` is an empty dict, the behavior to '
-            'save the best checkpoint will be skipped in this '
-            'evaluation.')
-        for warning in record_warnings:
-            if str(warning.message) == expected_message:
-                break
-        else:
-            assert False
+        # if metrics is an empty dict, print a warning information
+        checkpoint_hook = CheckpointHook(
+            interval=2, by_epoch=True, save_best='auto')
+        checkpoint_hook.after_val_epoch(runner, {})
+        runner.logger.warning.assert_called_once()
 
         # test error when number of rules and metrics are not same
         with pytest.raises(AssertionError) as assert_error:
@@ -187,93 +177,97 @@
             '"save_best", but got 3.')
         assert error_message in str(assert_error.value)
 
-        # if save_best is None,no best_ckpt meta should be stored
-        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best=None)
-        eval_hook.before_train(runner)
-        eval_hook.after_val_epoch(runner, None)
+        # if save_best is None, no best_ckpt meta should be stored
+        checkpoint_hook = CheckpointHook(
+            interval=2, by_epoch=True, save_best=None)
+        checkpoint_hook.before_train(runner)
+        checkpoint_hook.after_val_epoch(runner, {})
         assert 'best_score' not in runner.message_hub.runtime_info
         assert 'best_ckpt' not in runner.message_hub.runtime_info
 
         # when `save_best` is set to `auto`, first metric will be used.
         metrics = {'acc': 0.5, 'map': 0.3}
-        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best='auto')
-        eval_hook.before_train(runner)
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook = CheckpointHook(
+            interval=2, by_epoch=True, save_best='auto')
+        checkpoint_hook.before_train(runner)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         best_ckpt_name = 'best_acc_epoch_9.pth'
-        best_ckpt_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_ckpt_name)
-        assert eval_hook.key_indicators == ['acc']
-        assert eval_hook.rules == ['greater']
+        best_ckpt_path = checkpoint_hook.file_client.join_path(
+            checkpoint_hook.out_dir, best_ckpt_name)
+        assert checkpoint_hook.key_indicators == ['acc']
+        assert checkpoint_hook.rules == ['greater']
         assert 'best_score' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score') == 0.5
         assert 'best_ckpt' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_ckpt') == best_ckpt_path
 
         # # when `save_best` is set to `acc`, it should update greater value
-        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best='acc')
-        eval_hook.before_train(runner)
+        checkpoint_hook = CheckpointHook(
+            interval=2, by_epoch=True, save_best='acc')
+        checkpoint_hook.before_train(runner)
         metrics['acc'] = 0.8
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         assert 'best_score' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score') == 0.8
 
         # # when `save_best` is set to `loss`, it should update less value
-        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best='loss')
-        eval_hook.before_train(runner)
+        checkpoint_hook = CheckpointHook(
+            interval=2, by_epoch=True, save_best='loss')
+        checkpoint_hook.before_train(runner)
         metrics['loss'] = 0.8
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         metrics['loss'] = 0.5
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         assert 'best_score' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score') == 0.5
 
         # when `rule` is set to `less`,then it should update less value
         # no matter what `save_best` is
-        eval_hook = CheckpointHook(
+        checkpoint_hook = CheckpointHook(
             interval=2, by_epoch=True, save_best='acc', rule='less')
-        eval_hook.before_train(runner)
+        checkpoint_hook.before_train(runner)
         metrics['acc'] = 0.3
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         assert 'best_score' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score') == 0.3
 
         # # when `rule` is set to `greater`,then it should update greater value
         # # no matter what `save_best` is
-        eval_hook = CheckpointHook(
+        checkpoint_hook = CheckpointHook(
             interval=2, by_epoch=True, save_best='loss', rule='greater')
-        eval_hook.before_train(runner)
+        checkpoint_hook.before_train(runner)
         metrics['loss'] = 1.0
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         assert 'best_score' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score') == 1.0
 
         # test multi `save_best` with one rule
-        eval_hook = CheckpointHook(
+        checkpoint_hook = CheckpointHook(
             interval=2, save_best=['acc', 'mIoU'], rule='greater')
-        assert eval_hook.key_indicators == ['acc', 'mIoU']
-        assert eval_hook.rules == ['greater', 'greater']
+        assert checkpoint_hook.key_indicators == ['acc', 'mIoU']
+        assert checkpoint_hook.rules == ['greater', 'greater']
 
         # test multi `save_best` with multi rules
-        eval_hook = CheckpointHook(
+        checkpoint_hook = CheckpointHook(
             interval=2, save_best=['FID', 'IS'], rule=['less', 'greater'])
-        assert eval_hook.key_indicators == ['FID', 'IS']
-        assert eval_hook.rules == ['less', 'greater']
+        assert checkpoint_hook.key_indicators == ['FID', 'IS']
+        assert checkpoint_hook.rules == ['less', 'greater']
 
         # test multi `save_best` with default rule
-        eval_hook = CheckpointHook(interval=2, save_best=['acc', 'mIoU'])
-        assert eval_hook.key_indicators == ['acc', 'mIoU']
-        assert eval_hook.rules == ['greater', 'greater']
+        checkpoint_hook = CheckpointHook(interval=2, save_best=['acc', 'mIoU'])
+        assert checkpoint_hook.key_indicators == ['acc', 'mIoU']
+        assert checkpoint_hook.rules == ['greater', 'greater']
         runner.message_hub = MessageHub.get_instance(
             'test_after_val_epoch_save_multi_best')
-        eval_hook.before_train(runner)
+        checkpoint_hook.before_train(runner)
         metrics = dict(acc=0.5, mIoU=0.6)
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         best_acc_name = 'best_acc_epoch_9.pth'
-        best_acc_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_acc_name)
+        best_acc_path = checkpoint_hook.file_client.join_path(
+            checkpoint_hook.out_dir, best_acc_name)
         best_mIoU_name = 'best_mIoU_epoch_9.pth'
-        best_mIoU_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_mIoU_name)
+        best_mIoU_path = checkpoint_hook.file_client.join_path(
+            checkpoint_hook.out_dir, best_mIoU_name)
         assert 'best_score_acc' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score_acc') == 0.5
         assert 'best_score_mIoU' in runner.message_hub.runtime_info and \
@@ -293,26 +287,26 @@ def test_after_val_epoch(self, tmp_path):
 
         # check best ckpt name and best score
         metrics = {'acc': 0.5, 'map': 0.3}
-        eval_hook = CheckpointHook(
+        checkpoint_hook = CheckpointHook(
             interval=2, by_epoch=False, save_best='acc', rule='greater')
-        eval_hook.before_train(runner)
-        eval_hook.after_val_epoch(runner, metrics)
-        assert eval_hook.key_indicators == ['acc']
-        assert eval_hook.rules == ['greater']
+        checkpoint_hook.before_train(runner)
+        checkpoint_hook.after_val_epoch(runner, metrics)
+        assert checkpoint_hook.key_indicators == ['acc']
+        assert checkpoint_hook.rules == ['greater']
         best_ckpt_name = 'best_acc_iter_9.pth'
-        best_ckpt_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_ckpt_name)
+        best_ckpt_path = checkpoint_hook.file_client.join_path(
+            checkpoint_hook.out_dir, best_ckpt_name)
         assert 'best_ckpt' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_ckpt') == best_ckpt_path
         assert 'best_score' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score') == 0.5
 
         # check best score updating
         metrics['acc'] = 0.666
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         best_ckpt_name = 'best_acc_iter_9.pth'
-        best_ckpt_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_ckpt_name)
+        best_ckpt_path = checkpoint_hook.file_client.join_path(
+            checkpoint_hook.out_dir, best_ckpt_name)
         assert 'best_ckpt' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_ckpt') == best_ckpt_path
         assert 'best_score' in runner.message_hub.runtime_info and \
@@ -326,21 +320,21 @@ def test_after_val_epoch(self, tmp_path):
             interval=2, save_best='acc', rule=['greater', 'less'])
 
         # check best checkpoint name with `by_epoch` is False
-        eval_hook = CheckpointHook(
+        checkpoint_hook = CheckpointHook(
             interval=2, by_epoch=False, save_best=['acc', 'mIoU'])
-        assert eval_hook.key_indicators == ['acc', 'mIoU']
-        assert eval_hook.rules == ['greater', 'greater']
+        assert checkpoint_hook.key_indicators == ['acc', 'mIoU']
+        assert checkpoint_hook.rules == ['greater', 'greater']
         runner.message_hub = MessageHub.get_instance(
             'test_after_val_epoch_save_multi_best_by_epoch_is_false')
-        eval_hook.before_train(runner)
+        checkpoint_hook.before_train(runner)
         metrics = dict(acc=0.5, mIoU=0.6)
-        eval_hook.after_val_epoch(runner, metrics)
+        checkpoint_hook.after_val_epoch(runner, metrics)
         best_acc_name = 'best_acc_iter_9.pth'
-        best_acc_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_acc_name)
+        best_acc_path = checkpoint_hook.file_client.join_path(
+            checkpoint_hook.out_dir, best_acc_name)
         best_mIoU_name = 'best_mIoU_iter_9.pth'
-        best_mIoU_path = eval_hook.file_client.join_path(
-            eval_hook.out_dir, best_mIoU_name)
+        best_mIoU_path = checkpoint_hook.file_client.join_path(
+            checkpoint_hook.out_dir, best_mIoU_name)
         assert 'best_score_acc' in runner.message_hub.runtime_info and \
             runner.message_hub.get_info('best_score_acc') == 0.5
         assert 'best_score_mIoU' in runner.message_hub.runtime_info and \
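For reference, the `save_best`/`rule` combinations the test exercises correspond to hook configurations like the following (a sketch grounded in the test cases above; the variable names are illustrative only):

    from mmengine.hooks import CheckpointHook

    # `save_best='auto'` tracks the first metric the evaluator returns
    # (here that would be `acc`), with the default `greater` rule.
    auto_hook = CheckpointHook(interval=2, by_epoch=True, save_best='auto')

    # One rule broadcast across several metrics.
    multi_hook = CheckpointHook(
        interval=2, save_best=['acc', 'mIoU'], rule='greater')

    # One rule per metric: lower FID is better, higher IS is better.
    per_rule_hook = CheckpointHook(
        interval=2, save_best=['FID', 'IS'], rule=['less', 'greater'])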