@@ -11,6 +11,14 @@

 import torch
 import torch.nn as nn
+from packaging.version import Version
+
+if Version(torch.__version__) >= Version("1.9.0"):
+    from torch.distributed.optim import ZeroRedundancyOptimizer
+
+    HAVE_ZERO = True
+else:
+    HAVE_ZERO = False

 import ignite.distributed as idist
 from ignite.base import Serializable
@@ -166,13 +174,14 @@ class Checkpoint(Serializable):
         > checkpoint_12345.pt

     Note:
-        This class is distributed configuration-friendly: it is not required to instantiate the class in rank 0
-        process only.
+        This class is distributed configuration-friendly: it is not required to instantiate the class only in the
+        rank 0 process. This class automatically supports distributed configuration and, if used with
+        :class:`~ignite.handlers.DiskSaver`, the checkpoint is stored by the rank 0 process.

     .. warning::

-        When running on XLA devices, it should be run in all processes, otherwise application can get stuck on
-        saving the checkpoint.
+        When running on XLA devices or using :class:`~torch.distributed.optim.ZeroRedundancyOptimizer`, it
+        should be run in all processes, otherwise the application can get stuck while saving the checkpoint.

     .. code-block:: python

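To make the Note and warning above concrete, here is a small usage sketch. It is not part of this diff: the backend, the checkpoint directory, and the training loop are illustrative assumptions. The handler is created and called on every rank; `save_on_rank` selects the process that receives the consolidated ZeroRedundancyOptimizer state, and `DiskSaver` writes the checkpoint file on that same rank.

import torch
from torch.distributed.optim import ZeroRedundancyOptimizer

import ignite.distributed as idist
from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver


def training(local_rank):
    model = idist.auto_model(torch.nn.Linear(4, 2))
    # Optimizer state is sharded across the participating ranks.
    optimizer = ZeroRedundancyOptimizer(
        model.parameters(), optimizer_class=torch.optim.SGD, lr=0.1
    )

    def train_step(engine, batch):
        optimizer.zero_grad()
        model(batch).sum().backward()
        optimizer.step()

    trainer = Engine(train_step)

    # Created and attached on every rank; the ZeRO state is consolidated onto
    # `save_on_rank` and DiskSaver writes the file on that same rank.
    handler = Checkpoint(
        {"model": model, "optimizer": optimizer},
        DiskSaver("/tmp/checkpoints", require_empty=False, save_on_rank=0),
        n_saved=2,
        save_on_rank=0,
    )
    trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)

    data = [torch.randn(8, 4) for _ in range(10)]
    trainer.run(data, max_epochs=2)


# Illustrative launcher: two CPU processes with the gloo backend.
with idist.Parallel(backend="gloo", nproc_per_node=2) as parallel:
    parallel.run(training)
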
@@ -282,7 +291,7 @@ def __init__(
         filename_pattern: Optional[str] = None,
         include_self: bool = False,
         greater_or_equal: bool = False,
-        save_on_rank: Optional[int] = 0,
+        save_on_rank: int = 0,
     ):

         if not isinstance(to_save, collections.Mapping):
@@ -466,6 +475,10 @@ def _setup_checkpoint(self) -> Dict[str, Dict[Any, Any]]:
         for k, obj in self.to_save.items():
             if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
                 obj = obj.module
+            elif HAVE_ZERO and isinstance(obj, ZeroRedundancyOptimizer):
+                obj.consolidate_state_dict(to=self.save_on_rank)
+                if self.save_on_rank != idist.get_rank():
+                    continue
             checkpoint[k] = obj.state_dict()
         return checkpoint

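For reference, a minimal sketch of the PyTorch behaviour the new branch relies on, assuming a single-process gloo group purely so the collective calls are runnable: `consolidate_state_dict(to=rank)` must be executed by every rank, and only the target rank subsequently holds a complete `state_dict()`, which is why the non-saving ranks `continue` above.

import os

import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer

# Assumption: a throwaway single-process process group, just for illustration;
# real use spans multiple ranks launched by torchrun or idist.Parallel.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Linear(4, 2)
# Optimizer state is sharded across ranks instead of replicated on each one.
opt = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.SGD, lr=0.1)
model(torch.randn(8, 4)).sum().backward()
opt.step()

# Collective call: every rank must execute it; afterwards only the target rank
# (here rank 0) holds the full, consolidated optimizer state.
opt.consolidate_state_dict(to=0)
if dist.get_rank() == 0:
    full_state = opt.state_dict()  # safe only on the consolidation target rank
dist.destroy_process_group()
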
@@ -782,7 +795,7 @@ def __init__(
         atomic: bool = True,
         create_dir: bool = True,
         require_empty: bool = True,
-        save_on_rank: Optional[int] = 0,
+        save_on_rank: int = 0,
         **kwargs: Any,
     ):
         self.dirname = Path(dirname).expanduser()
@@ -948,7 +961,7 @@ def __init__(
         filename_pattern: Optional[str] = None,
         include_self: bool = False,
         greater_or_equal: bool = False,
-        save_on_rank: Optional[int] = 0,
+        save_on_rank: int = 0,
         **kwargs: Any,
     ):
