Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Publish models after training if published_keys is set in CheckpointHook #987

Merged
merged 57 commits into from
Mar 29, 2023
Merged
Show file tree
Hide file tree
Changes from 55 commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
672ee60
add publish keys in checkpointhook and update hook.md file
KerwinKai Mar 8, 2023
48e41fb
Update checkpoint_hook.py
KerwinKai Mar 8, 2023
104d478
Update hook.md
KerwinKai Mar 8, 2023
186b751
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 8, 2023
4638f32
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 8, 2023
0e8affc
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 8, 2023
6722a59
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 8, 2023
c08ecc5
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 8, 2023
e46f915
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 8, 2023
3cdb1c8
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 8, 2023
427bb58
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 8, 2023
6aab604
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 8, 2023
721f515
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 8, 2023
e4771fa
Update checkpoint_hook.py
KerwinKai Mar 8, 2023
bf7a601
Update docs/en/tutorials/hook.md
KerwinKai Mar 8, 2023
23a4b35
Merge branch 'open-mmlab:main' into KerwinKai/add_publish_keys
KerwinKai Mar 8, 2023
62c2495
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 8, 2023
1be2339
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 8, 2023
83f8fb6
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 8, 2023
176db6c
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 8, 2023
8c9961f
Update hook.md
KerwinKai Mar 8, 2023
c359530
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 8, 2023
599956a
Update checkpoint_hook.py
KerwinKai Mar 8, 2023
2c6528b
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 8, 2023
e60a232
Update checkpoint_hook.py
KerwinKai Mar 8, 2023
2351785
Update checkpoint_hook.py
KerwinKai Mar 8, 2023
f46e6b7
Update checkpoint_hook.py
KerwinKai Mar 8, 2023
78dcf4c
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 9, 2023
e452ff4
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 9, 2023
a165e81
Update checkpoint_hook.py
KerwinKai Mar 9, 2023
a47a9ac
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 9, 2023
d104f2f
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 9, 2023
4f90093
Merge branch 'open-mmlab:main' into KerwinKai/add_publish_keys
KerwinKai Mar 10, 2023
687b8eb
Update checkpoint_hook.py
KerwinKai Mar 10, 2023
362bdb3
Merge branch 'open-mmlab:main' into KerwinKai/add_publish_keys
KerwinKai Mar 13, 2023
e12b852
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 13, 2023
8a57047
Merge branch 'open-mmlab:main' into KerwinKai/add_publish_keys
KerwinKai Mar 13, 2023
268660b
Merge branch 'open-mmlab:main' into KerwinKai/add_publish_keys
KerwinKai Mar 15, 2023
55f14f6
Add Test for publish model
KerwinKai Mar 15, 2023
42640d6
Update checkpoint_hook.py
KerwinKai Mar 15, 2023
38af8c1
Update test_checkpoint_hook.py
KerwinKai Mar 15, 2023
40aafd5
Fix file to pass pre-commit check
KerwinKai Mar 15, 2023
b236741
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 17, 2023
3cda2ae
Merge branch 'open-mmlab:main' into KerwinKai/add_publish_keys
KerwinKai Mar 17, 2023
e0750f3
Fix mypy warning
KerwinKai Mar 17, 2023
912ffbb
rm not necessary line in checkpoint_hook.py
KerwinKai Mar 17, 2023
e2ae6b1
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 19, 2023
60d708a
Merge branch 'open-mmlab:main' into KerwinKai/add_publish_keys
KerwinKai Mar 20, 2023
551d84b
rm unnecessary messages add to message_hub
KerwinKai Mar 20, 2023
5a0009e
Update mmengine/hooks/checkpoint_hook.py
KerwinKai Mar 24, 2023
520d92a
Update docs/zh_cn/tutorials/hook.md
KerwinKai Mar 24, 2023
7c77367
Update docs/zh_cn/tutorials/hook.md
KerwinKai Mar 24, 2023
21178eb
Merge branch 'open-mmlab:main' into KerwinKai/add_publish_keys
KerwinKai Mar 27, 2023
b89a2ca
update checkpoint hook and hook.md file
KerwinKai Mar 27, 2023
85e1f59
Apply suggestions from code review
zhouzaida Mar 28, 2023
e833a00
Apply suggestions from code review
zhouzaida Mar 28, 2023
28499ad
Update mmengine/hooks/checkpoint_hook.py
zhouzaida Mar 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions docs/en/tutorials/hook.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ runner.train()
- Save the most recent checkpoints
- Save the best checkpoints
- Specify the path to save the checkpoints
- Make checkpoints for publish

For more features, please read the [CheckpointHook API documentation](mmengine.hooks.CheckpointHook).

Expand Down Expand Up @@ -120,6 +121,14 @@ The four features mentioned above are described below.
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, out_dir='/path/of/directory'))
```

- Make checkpoints for publish

If you want to automatically generate publishable checkpoints after training (remove unnecessary keys, such as optimizer state), you can set the `published_keys` parameter to choose which information to keep. Note: You need to set the `save_best` or `save_last` parameters accordingly so that the releasable checkpoints will be generated. Setting `save_best` will generate the releasable weights of the optimal checkpoint, and setting `save_last` will generate the releasable final checkpoint. These two parameters can also be set at the same time.

```python
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=1, save_best='accuracy', rule='less', published_keys=['meta', 'state_dict']))
```

[LoggerHook](mmengine.hooks.LoggerHook) collects logs from different components of `Runner` and write them to terminal, JSON file, tensorboard and wandb .etc.

If we want to output (or save) the logs every 20 iterations, we can set the `interval` parameter and configure it as follows.
Expand Down
9 changes: 9 additions & 0 deletions docs/zh_cn/tutorials/hook.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ runner.train()
- 保存最新的多个权重
- 保存最优权重
- 指定保存权重的路径
- 制作发布用的权重

如需了解其他功能,请阅读 [CheckpointHook API 文档](mmengine.hooks.CheckpointHook)。

Expand Down Expand Up @@ -121,6 +122,14 @@ runner.train()
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, out_dir='/path/of/directory'))
```

- 制作发布用的权重

如果你想在训练结束后自动生成可发布的权重(删除不需要的权重,例如优化器状态),你可以设置 `published_keys` 参数,选择需要保留的信息。注意:需要相应设置 `save_best` 或者 `save_last` 参数,这样才会生成可发布的权重,其中设置 `save_best` 会生成最优权重的可发布权重,设置 `save_last` 会生成最后一个权重的可发布权重,这两个参数也可同时设置。

```python
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=1, save_best='accuracy', rule='less', published_keys=['meta', 'state_dict']))
```

### LoggerHook

[LoggerHook](mmengine.hooks.LoggerHook) 负责收集日志并把日志输出到终端或者输出到文件、TensorBoard 等后端。
Expand Down
76 changes: 75 additions & 1 deletion mmengine/hooks/checkpoint_hook.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import hashlib
import logging
import os.path as osp
import pickle
from math import inf
from pathlib import Path
from typing import Callable, Dict, List, Optional, Sequence, Union
Expand Down Expand Up @@ -84,7 +86,11 @@ class CheckpointHook(Hook):
backend_args (dict, optional): Arguments to instantiate the
prefix of uri corresponding backend. Defaults to None.
New in v0.2.0.

published_keys (str, List[str], optional): If ``save_last`` is ``True``
or ``save_best`` is not ``None``, it will automatically
publish model with keys in the list after training.
Defaults to None.
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved
`New in version 0.7.1.`
Examples:
>>> # Save best based on single metric
>>> CheckpointHook(interval=2, by_epoch=True, save_best='acc',
Expand All @@ -95,6 +101,9 @@ class CheckpointHook(Hook):
>>> # Save best based on multi metrics with different comparison rule
>>> CheckpointHook(interval=2, by_epoch=True,
>>> save_best=['FID', 'IS'], rule=['less', 'greater'])
>>> # Save best based on single metric and publish model after training
>>> CheckpointHook(interval=2, by_epoch=True, save_best='acc',
>>> rule='less', published_keys=['meta', 'state_dict'])
"""
out_dir: str

Expand Down Expand Up @@ -128,6 +137,7 @@ def __init__(self,
file_client_args: Optional[dict] = None,
filename_tmpl: Optional[str] = None,
backend_args: Optional[dict] = None,
published_keys: Union[str, List[str], None] = None,
**kwargs) -> None:
self.interval = interval
self.by_epoch = by_epoch
Expand Down Expand Up @@ -218,6 +228,20 @@ def __init__(self,
else:
self.best_ckpt_path_dict: Dict = dict()

# published keys
if not (isinstance(published_keys, str)
or is_list_of(published_keys, str) or published_keys is None):
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved
raise TypeError(
'"published_keys" should be a str or list of str or None, '
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved
f'but got {type(published_keys)}')

if isinstance(published_keys, str):
published_keys = [published_keys]
elif isinstance(published_keys, list):
zhouzaida marked this conversation as resolved.
Show resolved Hide resolved
assert len(published_keys) == len(set(published_keys)), (
'Find duplicate elements in "published_keys".')
self.published_keys = published_keys

def before_train(self, runner) -> None:
"""Finish all operations, related to checkpoint.

Expand Down Expand Up @@ -304,6 +328,56 @@ def after_val_epoch(self, runner, metrics):

self._save_best_checkpoint(runner, metrics)

def after_train(self, runner) -> None:
"""Publish the checkpoint after training.

Args:
runner (Runner): The runner of the training process.
"""
if self.published_keys is None:
return

if self.save_last and 'last_ckpt' in runner.message_hub.runtime_info:
last_ckpt = runner.message_hub.get_info('last_ckpt')
self._publish_model(runner, last_ckpt)

if getattr(self, 'best_ckpt_path', None) is not None:
self._publish_model(runner, str(self.best_ckpt_path))
if getattr(self, 'best_ckpt_path_dict', None) is not None:
for key, best_ckpt in self.best_ckpt_path_dict.items():
self._publish_model(runner, best_ckpt)

def _publish_model(self, runner, ckpt_path: str) -> None:
"""Remove unnecessary keys from ckpt_path and save the new checkpoint.

Args:
runner (Runner): The runner of the training process.
ckpt_path (str): The checkpoint path that ought to be published.
"""
from mmengine.runner import save_checkpoint
from mmengine.runner.checkpoint import _load_checkpoint
checkpoint = _load_checkpoint(ckpt_path)
assert self.published_keys is not None
removed_keys = []
for key in list(checkpoint.keys()):
if key not in self.published_keys:
removed_keys.append(key)
checkpoint.pop(key)
if removed_keys:
print_log(
f'Key {removed_keys} will be removed because they are not '
'found in published_keys. If you want to keep them, '
f'please set `{removed_keys}` in published_keys',
logger='current')
KerwinKai marked this conversation as resolved.
Show resolved Hide resolved
checkpoint_data = pickle.dumps(checkpoint)
sha = hashlib.sha256(checkpoint_data).hexdigest()
final_path = osp.splitext(ckpt_path)[0] + f'-{sha[:8]}.pth'
save_checkpoint(checkpoint, final_path)
print_log(
f'The checkpoint ({ckpt_path}) is published to '
f'{final_path}.',
logger='current')

def _save_checkpoint(self, runner) -> None:
"""Save the current checkpoint and delete outdated checkpoint.

Expand Down
5 changes: 4 additions & 1 deletion tests/test_hooks/test_checkpoint_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,10 @@ def test_with_runner(self, tmp_path):
type='CheckpointHook',
interval=save_interval,
filename_tmpl=tmpl,
by_epoch=True)
by_epoch=True,
save_best='test/acc',
rule='less',
published_keys=['meta', 'state_dict'])
runner = Runner(
model=ToyModel(),
work_dir=work_dir,
Expand Down