import os
from typing import Any, Dict, Optional, Union

import torch
from torch import nn


class CheckpointManager(object):
r"""
A :class:`CheckpointManager` periodically serializes models and optimizer as .pth files during
training, and keeps track of best performing checkpoint based on an observed metric.
Extended Summary
----------------
It saves state dicts of models and optimizer as ``.pth`` files in a specified directory. This
class closely follows the API of PyTorch optimizers and learning rate schedulers.
Notes
-----
For :class:`~torch.nn.DataParallel` objects, ``.module.state_dict()`` is called instead of
``.state_dict()``.
Parameters
----------
models: Dict[str, torch.nn.Module]
Models which need to be serialized as a checkpoint.
optimizer: torch.optim.Optimizer
Optimizer which needs to be serialized as a checkpoint.
serialization_dir: str
Path to an empty or non-existent directory to save checkpoints.
mode: str, optional (default="max")
One of ``min``, ``max``. In ``min`` mode, best checkpoint will be recorded when metric
hits a lower value; in `max` mode it will be recorded when metric hits a higher value.
filename_prefix: str, optional (default="checkpoint")
Prefix of the to-be-saved checkpoint files.
Examples
--------
>>> model = torch.nn.Linear(10, 2)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> ckpt_manager = CheckpointManager({"model": model}, optimizer, "/tmp/ckpt", mode="min")
>>> num_epochs = 20
>>> for epoch in range(num_epochs):
... train(model)
... val_loss = validate(model)
... ckpt_manager.step(val_loss, epoch)
"""
    def __init__(
        self,
        models: Union[nn.Module, Dict[str, nn.Module]],
        serialization_dir: str,
        mode: str = "max",
        filename_prefix: str = "model",
    ):
        # Convert a single model to a dict.
        if isinstance(models, nn.Module):
            models = {"model": models}

        for key in models:
            if not isinstance(models[key], nn.Module):
                raise TypeError(
                    f"{key} is of type {type(models[key]).__name__}, not torch.nn.Module"
                )

        if mode not in {"min", "max"}:
            raise ValueError(f"mode must be one of 'min', 'max', got {mode}")

        self._models = models
        self._serialization_dir = serialization_dir
        self._mode = mode
        self._filename_prefix = filename_prefix

        # Create the serialization directory if it does not exist (a non-existent
        # path is explicitly allowed by the docstring).
        os.makedirs(self._serialization_dir, exist_ok=True)

        # Initialize member to hold the metric value of the best checkpoint so far.
        self._best_metric: Optional[Union[float, torch.Tensor]] = None
    def step(self, metric: Union[float, torch.Tensor]) -> bool:
        r"""Update the best checkpoint based on metric and mode, and serialize it if it
        improved. Return True if a checkpoint was saved, False otherwise.
        """
        # The first observed metric is the best so far. Compare against None
        # explicitly: a metric of 0.0 is falsy but still a valid value.
        if self._best_metric is None:
            self._best_metric = metric

        models_state_dict: Dict[str, Any] = {}
        for key in self._models:
            if isinstance(self._models[key], nn.DataParallel):
                # Unwrap DataParallel so the checkpoint loads without it.
                models_state_dict[key] = self._models[key].module.state_dict()
            else:
                models_state_dict[key] = self._models[key].state_dict()

        if (self._mode == "min" and metric <= self._best_metric) or (
            self._mode == "max" and metric >= self._best_metric
        ):
            self._best_metric = metric
            # Serialize the best checkpoint observed so far.
            torch.save(
                models_state_dict,
                os.path.join(
                    self._serialization_dir, f"{self._filename_prefix}-best.pth"
                ),
            )
            return True
        return False
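

# A minimal usage sketch, not part of the original module: it drives
# CheckpointManager.step with a synthetic, decreasing validation loss in "min"
# mode, so every epoch improves on the last and a best checkpoint is saved.
if __name__ == "__main__":
    import tempfile

    model = nn.Linear(10, 2)
    manager = CheckpointManager(model, tempfile.mkdtemp(), mode="min")

    for epoch in range(3):
        # Stand-in for a real validation loss (hypothetical values).
        val_loss = 1.0 / (epoch + 1)
        saved = manager.step(val_loss)
        print(f"epoch {epoch}: val_loss={val_loss:.3f}, saved_best={saved}")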