
Commit 06018c6

sdesrozis, Desroziers, and vfdev-5 authored
Add filename components in Checkpoint (#2498)
* add filename components in Checkpoint
* remove debug print
* add doc
* add Path to valid type
* handle Path as a valid type
* fix mypy
* remove mypy (error with win32 only)
* add mandatory empty lines for doctest
* rename private function

Co-authored-by: Desroziers <sylvain.desroziers@michelin.com>
Co-authored-by: vfdev <vfdev.5@gmail.com>
1 parent 85bb162 commit 06018c6
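In short, the commit adds a `Checkpoint.reload_objects` method that rebuilds a checkpoint's file path from filename components (`name`, `score`, `global_step`) instead of requiring an explicit filepath. A minimal sketch of the new API, assembled from the diff below; the directory, prefix, and object names are illustrative, not part of the commit:

```python
import torch
from ignite.handlers import ModelCheckpoint

model = torch.nn.Linear(3, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Illustrative handler; assumes training has already saved
# e.g. /tmp/models/myprefix_checkpoint_40.pt through this handler.
handler = ModelCheckpoint("/tmp/models", "myprefix", n_saved=None, create_dir=True)

# New in this commit: resolve the checkpoint file from filename components
# rather than passing a filepath to Checkpoint.load_objects yourself.
handler.reload_objects(
    to_load={"weights": model, "optimizer": optimizer},
    global_step=40,
)
```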

File tree

2 files changed: 126 additions & 29 deletions


ignite/handlers/checkpoint.py

Lines changed: 119 additions & 29 deletions
@@ -327,6 +327,18 @@ def __init__(
         self.include_self = include_self
         self.greater_or_equal = greater_or_equal
 
+    def _get_filename_pattern(self, global_step: Optional[int]) -> str:
+        if self.filename_pattern is None:
+            filename_pattern = self.setup_filename_pattern(
+                with_prefix=len(self.filename_prefix) > 0,
+                with_score=self.score_function is not None,
+                with_score_name=self.score_name is not None,
+                with_global_step=global_step is not None,
+            )
+        else:
+            filename_pattern = self.filename_pattern
+        return filename_pattern
+
     def reset(self) -> None:
         """Method to reset saved checkpoint names.
 
@@ -402,15 +414,7 @@ def __call__(self, engine: Engine) -> None:
                 name = k
             checkpoint = checkpoint[name]
 
-        if self.filename_pattern is None:
-            filename_pattern = self.setup_filename_pattern(
-                with_prefix=len(self.filename_prefix) > 0,
-                with_score=self.score_function is not None,
-                with_score_name=self.score_name is not None,
-                with_global_step=global_step is not None,
-            )
-        else:
-            filename_pattern = self.filename_pattern
+        filename_pattern = self._get_filename_pattern(global_step)
 
         filename_dict = {
             "filename_prefix": self.filename_prefix,
@@ -519,41 +523,51 @@ def _check_objects(objs: Mapping, attr: str) -> None:
                 raise TypeError(f"Object {type(obj)} should have `{attr}` method")
 
     @staticmethod
-    def load_objects(to_load: Mapping, checkpoint: Union[str, Mapping], **kwargs: Any) -> None:
+    def load_objects(to_load: Mapping, checkpoint: Union[str, Mapping, Path], **kwargs: Any) -> None:
         """Helper method to apply ``load_state_dict`` on the objects from ``to_load`` using states from ``checkpoint``.
 
         Args:
             to_load: a dictionary with objects, e.g. `{"model": model, "optimizer": optimizer, ...}`
-            checkpoint: a string filepath or a dictionary with state_dicts to load, e.g. `{"model": model_state_dict,
-                "optimizer": opt_state_dict}`. If `to_load` contains a single key, then checkpoint can contain
-                directly corresponding state_dict.
+            checkpoint: a path, a string filepath or a dictionary with state_dicts to load, e.g.
+                `{"model": model_state_dict, "optimizer": opt_state_dict}`. If `to_load` contains a single key,
+                then checkpoint can contain directly corresponding state_dict.
             kwargs: Keyword arguments accepted for `nn.Module.load_state_dict()`. Passing `strict=False` enables
                 the user to load part of the pretrained model (useful for example, in Transfer Learning)
 
         Examples:
             .. code-block:: python
 
+                import tempfile
+                from pathlib import Path
+
                 import torch
+
                 from ignite.engine import Engine, Events
                 from ignite.handlers import ModelCheckpoint, Checkpoint
+
                 trainer = Engine(lambda engine, batch: None)
-                handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=None, create_dir=True)
-                model = torch.nn.Linear(3, 3)
-                optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
-                to_save = {"weights": model, "optimizer": optimizer}
-                trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, to_save)
-                trainer.run(torch.randn(10, 1), 5)
 
-                to_load = to_save
-                checkpoint_fp = "/tmp/models/myprefix_checkpoint_40.pth"
-                checkpoint = torch.load(checkpoint_fp)
-                Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
+                with tempfile.TemporaryDirectory() as tmpdirname:
+                    handler = ModelCheckpoint(tmpdirname, 'myprefix', n_saved=None, create_dir=True)
+
+                    model = torch.nn.Linear(3, 3)
+                    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
+
+                    to_save = {"weights": model, "optimizer": optimizer}
+
+                    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, to_save)
+                    trainer.run(torch.randn(10, 1), 5)
+
+                    to_load = to_save
+                    checkpoint_fp = Path(tmpdirname) / 'myprefix_checkpoint_40.pt'
+                    checkpoint = torch.load(checkpoint_fp)
+                    Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
 
-                # or using a string for checkpoint filepath
+                    # or using a string for checkpoint filepath
 
-                to_load = to_save
-                checkpoint_fp = "/tmp/models/myprefix_checkpoint_40.pth"
-                Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint_fp)
+                    to_load = to_save
+                    checkpoint_fp = Path(tmpdirname) / 'myprefix_checkpoint_40.pt'
+                    Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint_fp)
 
         Note:
             If ``to_load`` contains objects of type torch `DistributedDataParallel`_ or
@@ -564,13 +578,13 @@ def load_objects(to_load: Mapping, checkpoint: Union[str, Mapping], **kwargs: An
         .. _DataParallel: https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html
         """
 
-        if isinstance(checkpoint, str):
+        if isinstance(checkpoint, (str, Path)):
             checkpoint_obj = torch.load(checkpoint)
         else:
             checkpoint_obj = checkpoint
 
         Checkpoint._check_objects(to_load, "load_state_dict")
-        if not isinstance(checkpoint, (collections.Mapping, str)):
+        if not isinstance(checkpoint, (collections.Mapping, str, Path)):
             raise TypeError(f"Argument checkpoint should be a string or a dictionary, but given {type(checkpoint)}")
 
         if len(kwargs) > 1 or any(k for k in kwargs if k not in ["strict"]):
@@ -599,6 +613,82 @@ def _load_object(obj: Any, chkpt_obj: Any) -> None:
                 raise ValueError(f"Object labeled by '{k}' from `to_load` is not found in the checkpoint")
             _load_object(obj, checkpoint_obj[k])
 
+    def reload_objects(self, to_load: Mapping, load_kwargs: Optional[Dict] = None, **filename_components: Any) -> None:
+        """Helper method to apply ``load_state_dict`` on the objects from ``to_load``. Filename components such as
+        name, score and global step can be configured.
+
+        Args:
+            to_load: a dictionary with objects, e.g. `{"model": model, "optimizer": optimizer, ...}`
+            load_kwargs: Keyword arguments accepted for `nn.Module.load_state_dict()`. Passing `strict=False` enables
+                the user to load part of the pretrained model (useful for example, in Transfer Learning)
+            filename_components: Filename components used to define the checkpoint file path.
+                Keyword arguments accepted are `name`, `score` and `global_step`.
+
+        Examples:
+            .. code-block:: python
+
+                import tempfile
+
+                import torch
+
+                from ignite.engine import Engine, Events
+                from ignite.handlers import ModelCheckpoint, Checkpoint
+
+                trainer = Engine(lambda engine, batch: None)
+
+                with tempfile.TemporaryDirectory() as tmpdirname:
+                    checkpoint = ModelCheckpoint(tmpdirname, 'myprefix', n_saved=None, create_dir=True)
+
+                    model = torch.nn.Linear(3, 3)
+                    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
+
+                    to_save = {"weights": model, "optimizer": optimizer}
+
+                    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), checkpoint, to_save)
+                    trainer.run(torch.randn(10, 1), 5)
+
+                    to_load = to_save
+                    # load checkpoint myprefix_checkpoint_40.pt
+                    checkpoint.reload_objects(to_load=to_load, global_step=40)
+
+        Note:
+            If ``to_load`` contains objects of type torch `DistributedDataParallel`_ or
+            `DataParallel`_, method ``load_state_dict`` will be applied to their internal wrapped model (``obj.module``).
+
+        .. _DistributedDataParallel: https://pytorch.org/docs/stable/generated/
+            torch.nn.parallel.DistributedDataParallel.html
+        .. _DataParallel: https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html
+        """
+
+        global_step = filename_components.get("global_step", None)
+
+        filename_pattern = self._get_filename_pattern(global_step)
+
+        checkpoint = self._setup_checkpoint()
+        name = "checkpoint"
+        if len(checkpoint) == 1:
+            for k in checkpoint:
+                name = k
+        name = filename_components.get("name", name)
+        score = filename_components.get("score", None)
+
+        filename_dict = {
+            "filename_prefix": self.filename_prefix,
+            "ext": self.ext,
+            "name": name,
+            "score_name": self.score_name,
+            "score": score,
+            "global_step": global_step,
+        }
+
+        checkpoint_fp = filename_pattern.format(**filename_dict)
+
+        path = self.save_handler.dirname / checkpoint_fp
+
+        load_kwargs = {} if load_kwargs is None else load_kwargs
+
+        Checkpoint.load_objects(to_load=to_load, checkpoint=path, **load_kwargs)
+
     def state_dict(self) -> "OrderedDict[str, List[Tuple[int, str]]]":
         """Method returns state dict with saved items: list of ``(priority, filename)`` pairs.
         Can be used to save internal state of the class.
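For reference, `reload_objects` reconstructs the file path purely from the filename pattern and the components above. A sketch of the formatting step, assuming the pattern returned by `_get_filename_pattern` for a prefixed, global-step-only checkpoint resolves to `{filename_prefix}_{name}_{global_step}.{ext}` (an assumption consistent with the `myprefix_checkpoint_40.pt` filenames in the docstring examples):

```python
# Hypothetical stand-in for the pattern returned by _get_filename_pattern
# when only a prefix and a global step are present.
filename_pattern = "{filename_prefix}_{name}_{global_step}.{ext}"

filename_dict = {
    "filename_prefix": "myprefix",
    "ext": "pt",
    "name": "checkpoint",  # default name when to_save holds several objects
    "score_name": None,
    "score": None,
    "global_step": 40,
}

# Resolves to the filename that reload_objects then joins with
# save_handler.dirname before delegating to Checkpoint.load_objects.
print(filename_pattern.format(**filename_dict))  # myprefix_checkpoint_40.pt
```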

tests/ignite/handlers/test_checkpoint.py

Lines changed: 7 additions & 0 deletions
@@ -576,6 +576,9 @@ def test_model_checkpoint_simple_recovery(dirname):
     assert fname.exists()
     loaded_objects = torch.load(fname)
     assert loaded_objects == model.state_dict()
+    to_load = {"model": DummyModel()}
+    h.reload_objects(to_load=to_load, global_step=1)
+    assert to_load["model"].state_dict() == model.state_dict()
 
 
 def test_model_checkpoint_simple_recovery_from_existing_non_empty(dirname):
@@ -600,6 +603,9 @@ def _test(ext, require_empty):
         assert previous_fname.exists()
         loaded_objects = torch.load(fname)
         assert loaded_objects == model.state_dict()
+        to_load = {"model": DummyModel()}
+        h.reload_objects(to_load=to_load, global_step=1)
+        assert to_load["model"].state_dict() == model.state_dict()
         fname.unlink()
 
     _test(".txt", require_empty=True)
@@ -1118,6 +1124,7 @@ def _get_multiple_objs_to_save():
     assert str(dirname / _PREFIX) in str(fname)
     assert fname.exists()
     Checkpoint.load_objects(to_save, str(fname))
+    Checkpoint.load_objects(to_save, fname)
     fname.unlink()
 
     # case: multiple objects
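The added test line exercises the other user-visible change: `Checkpoint.load_objects` now accepts a `pathlib.Path` in addition to a string or a mapping. A minimal sketch of that usage; the path is illustrative and assumed to point at a checkpoint saved earlier with a "weights" entry:

```python
from pathlib import Path

import torch
from ignite.handlers import Checkpoint

model = torch.nn.Linear(3, 3)

# Before this commit, only str or Mapping were accepted here; a Path now
# works too and is loaded via torch.load internally.
checkpoint_fp = Path("/tmp/models") / "myprefix_checkpoint_40.pt"
Checkpoint.load_objects(to_load={"weights": model}, checkpoint=checkpoint_fp)
```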
