@@ -327,6 +327,18 @@ def __init__(
        self.include_self = include_self
        self.greater_or_equal = greater_or_equal

+    def _get_filename_pattern(self, global_step: Optional[int]) -> str:
+        if self.filename_pattern is None:
+            filename_pattern = self.setup_filename_pattern(
+                with_prefix=len(self.filename_prefix) > 0,
+                with_score=self.score_function is not None,
+                with_score_name=self.score_name is not None,
+                with_global_step=global_step is not None,
+            )
+        else:
+            filename_pattern = self.filename_pattern
+        return filename_pattern
+
    def reset(self) -> None:
        """Method to reset saved checkpoint names.

@@ -402,15 +414,7 @@ def __call__(self, engine: Engine) -> None:
                name = k
            checkpoint = checkpoint[name]

-        if self.filename_pattern is None:
-            filename_pattern = self.setup_filename_pattern(
-                with_prefix=len(self.filename_prefix) > 0,
-                with_score=self.score_function is not None,
-                with_score_name=self.score_name is not None,
-                with_global_step=global_step is not None,
-            )
-        else:
-            filename_pattern = self.filename_pattern
+        filename_pattern = self._get_filename_pattern(global_step)

        filename_dict = {
            "filename_prefix": self.filename_prefix,
@@ -519,41 +523,51 @@ def _check_objects(objs: Mapping, attr: str) -> None:
                raise TypeError(f"Object {type(obj)} should have `{attr}` method")

    @staticmethod
-    def load_objects(to_load: Mapping, checkpoint: Union[str, Mapping], **kwargs: Any) -> None:
+    def load_objects(to_load: Mapping, checkpoint: Union[str, Mapping, Path], **kwargs: Any) -> None:
        """Helper method to apply ``load_state_dict`` on the objects from ``to_load`` using states from ``checkpoint``.

        Args:
            to_load: a dictionary with objects, e.g. `{"model": model, "optimizer": optimizer, ...}`
-            checkpoint: a string filepath or a dictionary with state_dicts to load, e.g. `{"model": model_state_dict,
-                "optimizer": opt_state_dict}`. If `to_load` contains a single key, then checkpoint can contain
-                directly corresponding state_dict.
+            checkpoint: a path, a string filepath or a dictionary with state_dicts to load, e.g.
+                `{"model": model_state_dict, "optimizer": opt_state_dict}`. If `to_load` contains a single key,
+                then checkpoint can contain directly corresponding state_dict.
            kwargs: Keyword arguments accepted for `nn.Module.load_state_dict()`. Passing `strict=False` enables
                the user to load part of the pretrained model (useful for example, in Transfer Learning)

        Examples:
            .. code-block:: python

+                import tempfile
+                from pathlib import Path
+
                import torch
+
                from ignite.engine import Engine, Events
                from ignite.handlers import ModelCheckpoint, Checkpoint
+
                trainer = Engine(lambda engine, batch: None)
-                handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=None, create_dir=True)
-                model = torch.nn.Linear(3, 3)
-                optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
-                to_save = {"weights": model, "optimizer": optimizer}
-                trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, to_save)
-                trainer.run(torch.randn(10, 1), 5)

-                to_load = to_save
-                checkpoint_fp = "/tmp/models/myprefix_checkpoint_40.pth"
-                checkpoint = torch.load(checkpoint_fp)
-                Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
+                with tempfile.TemporaryDirectory() as tmpdirname:
+                    handler = ModelCheckpoint(tmpdirname, 'myprefix', n_saved=None, create_dir=True)
+
+                    model = torch.nn.Linear(3, 3)
+                    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
+
+                    to_save = {"weights": model, "optimizer": optimizer}
+
+                    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, to_save)
+                    trainer.run(torch.randn(10, 1), 5)
+
+                    to_load = to_save
+                    checkpoint_fp = Path(tmpdirname) / 'myprefix_checkpoint_40.pt'
+                    checkpoint = torch.load(checkpoint_fp)
+                    Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)

-                # or using a string for checkpoint filepath
+                    # or using a string for checkpoint filepath

-                to_load = to_save
-                checkpoint_fp = "/tmp/models/myprefix_checkpoint_40.pth"
-                Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint_fp)
+                    to_load = to_save
+                    checkpoint_fp = Path(tmpdirname) / 'myprefix_checkpoint_40.pt'
+                    Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint_fp)

        Note:
            If ``to_load`` contains objects of type torch `DistributedDataParallel`_ or
@@ -564,13 +578,13 @@ def load_objects(to_load: Mapping, checkpoint: Union[str, Mapping], **kwargs: An
            .. _DataParallel: https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html
        """

-        if isinstance(checkpoint, str):
+        if isinstance(checkpoint, (str, Path)):
            checkpoint_obj = torch.load(checkpoint)
        else:
            checkpoint_obj = checkpoint

        Checkpoint._check_objects(to_load, "load_state_dict")
-        if not isinstance(checkpoint, (collections.Mapping, str)):
+        if not isinstance(checkpoint, (collections.Mapping, str, Path)):
            raise TypeError(f"Argument checkpoint should be a string or a dictionary, but given {type(checkpoint)}")

        if len(kwargs) > 1 or any(k for k in kwargs if k not in ["strict"]):
@@ -599,6 +613,82 @@ def _load_object(obj: Any, chkpt_obj: Any) -> None:
                raise ValueError(f"Object labeled by '{k}' from `to_load` is not found in the checkpoint")
            _load_object(obj, checkpoint_obj[k])

+    def reload_objects(self, to_load: Mapping, load_kwargs: Optional[Dict] = None, **filename_components: Any) -> None:
+        """Helper method to apply ``load_state_dict`` on the objects from ``to_load``. Filename components such as
+        name, score and global step can be configured.
+
+        Args:
+            to_load: a dictionary with objects, e.g. `{"model": model, "optimizer": optimizer, ...}`
+            load_kwargs: Keyword arguments accepted for `nn.Module.load_state_dict()`. Passing `strict=False` enables
+                the user to load part of the pretrained model (useful for example, in Transfer Learning)
+            filename_components: Filename components used to define the checkpoint file path.
+                Keyword arguments accepted are `name`, `score` and `global_step`.
+
+        Examples:
+            .. code-block:: python
+
+                import tempfile
+
+                import torch
+
+                from ignite.engine import Engine, Events
+                from ignite.handlers import ModelCheckpoint, Checkpoint
+
+                trainer = Engine(lambda engine, batch: None)
+
+                with tempfile.TemporaryDirectory() as tmpdirname:
+                    checkpoint = ModelCheckpoint(tmpdirname, 'myprefix', n_saved=None, create_dir=True)
+
+                    model = torch.nn.Linear(3, 3)
+                    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
+
+                    to_save = {"weights": model, "optimizer": optimizer}
+
+                    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), checkpoint, to_save)
+                    trainer.run(torch.randn(10, 1), 5)
+
+                    to_load = to_save
+                    # load checkpoint myprefix_checkpoint_40.pt
+                    checkpoint.reload_objects(to_load=to_load, global_step=40)
+
+        Note:
+            If ``to_load`` contains objects of type torch `DistributedDataParallel`_ or
+            `DataParallel`_, method ``load_state_dict`` will be applied to their internal wrapped model (``obj.module``).
+
+            .. _DistributedDataParallel: https://pytorch.org/docs/stable/generated/
+                torch.nn.parallel.DistributedDataParallel.html
+            .. _DataParallel: https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html
+        """
+
+        global_step = filename_components.get("global_step", None)
+
+        filename_pattern = self._get_filename_pattern(global_step)
+
+        checkpoint = self._setup_checkpoint()
+        name = "checkpoint"
+        if len(checkpoint) == 1:
+            for k in checkpoint:
+                name = k
+        name = filename_components.get("name", name)
+        score = filename_components.get("score", None)
+
+        filename_dict = {
+            "filename_prefix": self.filename_prefix,
+            "ext": self.ext,
+            "name": name,
+            "score_name": self.score_name,
+            "score": score,
+            "global_step": global_step,
+        }
+
+        checkpoint_fp = filename_pattern.format(**filename_dict)
+
+        path = self.save_handler.dirname / checkpoint_fp
+
+        load_kwargs = {} if load_kwargs is None else load_kwargs
+
+        Checkpoint.load_objects(to_load=to_load, checkpoint=path, **load_kwargs)
+
    def state_dict(self) -> "OrderedDict[str, List[Tuple[int, str]]]":
        """Method returns state dict with saved items: list of ``(priority, filename)`` pairs.
        Can be used to save internal state of the class.
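
For reference, a minimal sketch of the path resolution that the new ``reload_objects`` performs, assuming the default pattern that yields filenames like ``myprefix_checkpoint_40.pt`` in the examples above. The pattern string, the ``/tmp/models`` directory and the literal values below are illustrative only, not part of the diff:

.. code-block:: python

    from pathlib import Path

    # Illustrative default pattern: prefix and global_step present, no score.
    filename_pattern = "{filename_prefix}_{name}_{global_step}.{ext}"

    filename_dict = {
        "filename_prefix": "myprefix",  # handler's filename_prefix
        "ext": "pt",                    # handler's extension
        "name": "checkpoint",           # to_save has two keys, so the default "checkpoint" is kept
        "score_name": None,             # no score_function configured
        "score": None,
        "global_step": 40,              # taken from the global_step keyword argument
    }

    checkpoint_fp = filename_pattern.format(**filename_dict)  # "myprefix_checkpoint_40.pt"
    path = Path("/tmp/models") / checkpoint_fp                 # save_handler.dirname / checkpoint_fp

The resulting ``path`` is then passed to ``Checkpoint.load_objects``, which is why the ``(str, Path)`` checks were added in the earlier hunks.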