Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Publish models after training if published_keys is set in CheckpointHook #987

Merged
merged 57 commits into from
Mar 29, 2023
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
672ee60
add publish keys in checkpointhook and update hook.md file
KerwinKai Mar 8, 2023
48e41fb
Update checkpoint_hook.py
KerwinKai Mar 8, 2023
104d478
Update hook.md
KerwinKai Mar 8, 2023
186b751
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 8, 2023
4638f32
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 8, 2023
0e8affc
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 8, 2023
6722a59
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 8, 2023
c08ecc5
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 8, 2023
e46f915
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 8, 2023
3cdb1c8
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 8, 2023
427bb58
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 8, 2023
6aab604
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 8, 2023
721f515
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 8, 2023
e4771fa
Update checkpoint_hook.py
KerwinKai Mar 8, 2023
bf7a601
Update docs/en/tutorials/hook.md
KerwinKai Mar 8, 2023
23a4b35
Merge branch 'open-mmlab:main' into KerwinKai/add_publish_keys
KerwinKai Mar 8, 2023
62c2495
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 8, 2023
1be2339
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 8, 2023
83f8fb6
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 8, 2023
176db6c
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 8, 2023
8c9961f
Update hook.md
KerwinKai Mar 8, 2023
c359530
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 8, 2023
599956a
Update checkpoint_hook.py
KerwinKai Mar 8, 2023
2c6528b
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 8, 2023
e60a232
Update checkpoint_hook.py
KerwinKai Mar 8, 2023
2351785
Update checkpoint_hook.py
KerwinKai Mar 8, 2023
f46e6b7
Update checkpoint_hook.py
KerwinKai Mar 8, 2023
78dcf4c
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 9, 2023
e452ff4
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 9, 2023
a165e81
Update checkpoint_hook.py
KerwinKai Mar 9, 2023
a47a9ac
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 9, 2023
d104f2f
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 9, 2023
4f90093
Merge branch 'open-mmlab:main' into KerwinKai/add_publish_keys
KerwinKai Mar 10, 2023
687b8eb
Update checkpoint_hook.py
KerwinKai Mar 10, 2023
362bdb3
Merge branch 'open-mmlab:main' into KerwinKai/add_publish_keys
KerwinKai Mar 13, 2023
e12b852
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 13, 2023
8a57047
Merge branch 'open-mmlab:main' into KerwinKai/add_publish_keys
KerwinKai Mar 13, 2023
268660b
Merge branch 'open-mmlab:main' into KerwinKai/add_publish_keys
KerwinKai Mar 15, 2023
55f14f6
Add Test for publish model
KerwinKai Mar 15, 2023
42640d6
Update checkpoint_hook.py
KerwinKai Mar 15, 2023
38af8c1
Update test_checkpoint_hook.py
KerwinKai Mar 15, 2023
40aafd5
Fix file to pass pre-commit check
KerwinKai Mar 15, 2023
b236741
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 17, 2023
3cda2ae
Merge branch 'open-mmlab:main' into KerwinKai/add_publish_keys
KerwinKai Mar 17, 2023
e0750f3
Fix mypy warning
KerwinKai Mar 17, 2023
912ffbb
rm not necessary line in checkpoint_hook.py
KerwinKai Mar 17, 2023
e2ae6b1
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 19, 2023
60d708a
Merge branch 'open-mmlab:main' into KerwinKai/add_publish_keys
KerwinKai Mar 20, 2023
551d84b
rm unnecessary messages add to message_hub
KerwinKai Mar 20, 2023
5a0009e
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 24, 2023
520d92a
Update docs/zh_cn/tutorials/hook.md
KerwinKai Mar 24, 2023
7c77367
Update docs/zh_cn/tutorials/hook.md
KerwinKai Mar 24, 2023
21178eb
Merge branch 'open-mmlab:main' into KerwinKai/add_publish_keys
KerwinKai Mar 27, 2023
b89a2ca
update checkpoint hook and hook.md file
KerwinKai Mar 27, 2023
85e1f59
Apply suggestions from code review
zhouzaida Mar 28, 2023
e833a00
Apply suggestions from code review
zhouzaida Mar 28, 2023
28499ad
Update mmengine/hooks/checkpoint_hook.py
zhouzaida Mar 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions docs/en/tutorials/hook.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ runner.train()
- Save the most recent checkpoints
- Save the best checkpoints
- Specify the path to save the checkpoints
- Automatically publish the best and the last checkpoints

For more features, please read the [CheckpointHook API documentation](mmengine.hooks.CheckpointHook).

Expand Down Expand Up @@ -120,6 +121,14 @@ The four features mentioned above are described below.
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, out_dir='/path/of/directory'))
```

- Automatically publish the best and the last checkpoints

If you want to publish the best and the last checkpoints after training, you can set the `published_keys` parameter. You can select any keys in checkpoint to be published.

```python
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=1, save_best='accuracy', rule='less', published_keys=['meta', 'state_dict']))
```

[LoggerHook](mmengine.hooks.LoggerHook) collects logs from different components of `Runner` and write them to terminal, JSON file, tensorboard and wandb .etc.

If we want to output (or save) the logs every 20 iterations, we can set the `interval` parameter and configure it as follows.
Expand Down
69 changes: 68 additions & 1 deletion mmengine/hooks/checkpoint_hook.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import subprocess
import torch
import os.path as osp
from math import inf
from pathlib import Path
Expand Down Expand Up @@ -84,7 +86,9 @@ class CheckpointHook(Hook):
backend_args (dict, optional): Arguments to instantiate the
prefix of uri corresponding backend. Defaults to None.
New in v0.2.0.

published_keys (str, List[str], optional): If ``save_last`` is ``True``
or ``save_best`` is not ``None``, it will automatically
publish model with keys in list after train. Defaults to None.
KerwinKai marked this conversation as resolved.
Show resolved Hide resolved
Examples:
>>> # Save best based on single metric
>>> CheckpointHook(interval=2, by_epoch=True, save_best='acc',
Expand All @@ -95,6 +99,9 @@ class CheckpointHook(Hook):
>>> # Save best based on multi metrics with different comparison rule
>>> CheckpointHook(interval=2, by_epoch=True,
>>> save_best=['FID', 'IS'], rule=['less', 'greater'])
>>> # Save best based on single metric and publish model after train
KerwinKai marked this conversation as resolved.
Show resolved Hide resolved
>>> CheckpointHook(interval=2, by_epoch=True, save_best='acc',
>>> rule='less', published_keys=['meta', 'state_dict'])
"""
out_dir: str

Expand Down Expand Up @@ -128,6 +135,7 @@ def __init__(self,
file_client_args: Optional[dict] = None,
filename_tmpl: Optional[str] = None,
backend_args: Optional[dict] = None,
published_keys: Union[str, List[str], None] = None,
**kwargs) -> None:
self.interval = interval
self.by_epoch = by_epoch
Expand Down Expand Up @@ -218,6 +226,21 @@ def __init__(self,
else:
self.best_ckpt_path_dict: Dict = dict()

# published keys
if not (isinstance(published_keys, str)
or is_list_of(published_keys, str)
or (published_keys is None)), (
'"published_keys" should be a str or list of str or None, '
f'but got {type(published_keys)}'):
raise TypeError(...)
KerwinKai marked this conversation as resolved.
Show resolved Hide resolved

if isinstace(published_keys, str):
published_keys = [published_keys]
elif isinstance(published_keys, list):
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved
assert len(published_keys) == len(set(published_keys)), (
'Find duplicate element in "published_keys".')
KerwinKai marked this conversation as resolved.
Show resolved Hide resolved
self.published_keys = published_keys

def before_train(self, runner) -> None:
"""Finish all operations, related to checkpoint.

Expand Down Expand Up @@ -304,6 +327,50 @@ def after_val_epoch(self, runner, metrics):

self._save_best_checkpoint(runner, metrics)

def after_train(self, runner) -> None:
"""Publish the checkpoint after train epoch.
KerwinKai marked this conversation as resolved.
Show resolved Hide resolved

Args:
runner (Runner): The runner of the training process.
"""
if self.published_keys is None:
return

if self.save_last:
last_ckpt = runner.message_hub.get_info('last_ckpt')
assert last_ckpt ('Did not find last_checkpoint to be resumed.')
self._publish_model(runner, last_ckpt)
KerwinKai marked this conversation as resolved.
Show resolved Hide resolved

if self.save_best is not None:
if not self.best_ckpt_path:
raise RuntimeError(xxx)
self._publish_model(runner, best_ckpt)
KerwinKai marked this conversation as resolved.
Show resolved Hide resolved

def _publish_model(self, runner, out_file: str) -> None:

if published_keys is None:
return
KerwinKai marked this conversation as resolved.
Show resolved Hide resolved
checkpoint = runner.load_checkpoint(out_file)
KerwinKai marked this conversation as resolved.
Show resolved Hide resolved
ckpt_keys = list(checkpoint.keys())
published_keys = self.published_keys
KerwinKai marked this conversation as resolved.
Show resolved Hide resolved
for k in ckpt_keys:
if k not in published_keys:
print_log(
f'Key `{k}` will be removed because it is not in '
f'save_keys. If you want to keep it, '
f'please set `{k}` in published_keys',
logger='current')
checkpoint.pop(k)
KerwinKai marked this conversation as resolved.
Show resolved Hide resolved
out_file_name = osp.splitext(out_file)[0]
tmp_out_file_name = out_file + '.pth'
torch.save(checkpoint, tmp_out_file_name)
sha = subprocess.check_output(['sha256sum',
tmp_out_file_name]).decode()
KerwinKai marked this conversation as resolved.
Show resolved Hide resolved
final_file = out_file_name + f'-{sha[:8]}.pth'
subprocess.Popen(['mv', tmp_out_file_name, final_file])
KerwinKai marked this conversation as resolved.
Show resolved Hide resolved
print_log(
f'The published model is saved at {final_file}.', logger='current')

def _save_checkpoint(self, runner) -> None:
"""Save the current checkpoint and delete outdated checkpoint.

Expand Down