```diff
@@ -16,6 +16,10 @@
 if Version(torch.__version__) >= Version("1.9.0"):
     from torch.distributed.optim import ZeroRedundancyOptimizer
 
+    HAVE_ZERO = True
+else:
+    HAVE_ZERO = False
+
 import ignite.distributed as idist
 from ignite.base import Serializable
 from ignite.engine import Engine, Events
```
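A minimal sketch of the pattern this hunk introduces: the availability check runs once at import time, and later call sites branch on the module-level `HAVE_ZERO` flag instead of re-comparing version strings. The `Version` helper is assumed to come from `packaging.version` (as the surrounding file's `Version(torch.__version__)` call suggests), and `is_zero_optimizer` is an illustrative helper, not part of the PR.

```python
from packaging.version import Version  # assumed source of Version in this file
import torch

# One-time capability probe: import the optional class and record availability.
if Version(torch.__version__) >= Version("1.9.0"):
    from torch.distributed.optim import ZeroRedundancyOptimizer

    HAVE_ZERO = True
else:
    HAVE_ZERO = False


def is_zero_optimizer(obj) -> bool:
    # Short-circuit on the flag so ZeroRedundancyOptimizer is never looked up
    # on torch versions where the import above was skipped.
    return HAVE_ZERO and isinstance(obj, ZeroRedundancyOptimizer)
```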
```diff
@@ -287,7 +291,7 @@ def __init__(
         filename_pattern: Optional[str] = None,
         include_self: bool = False,
         greater_or_equal: bool = False,
-        save_on_rank: Optional[int] = 0,
+        save_on_rank: int = 0,
     ):
 
         if not isinstance(to_save, collections.Mapping):
```
```diff
@@ -471,7 +475,7 @@ def _setup_checkpoint(self) -> Dict[str, Dict[Any, Any]]:
             for k, obj in self.to_save.items():
                 if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
                     obj = obj.module
-                elif Version(torch.__version__) >= Version("1.9.0") and isinstance(obj, ZeroRedundancyOptimizer):
+                elif HAVE_ZERO and isinstance(obj, ZeroRedundancyOptimizer):
                     obj.consolidate_state_dict(to=self.save_on_rank)
                     if self.save_on_rank != idist.get_rank():
                         continue
```
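For context, a hedged sketch of the path this hunk touches: with a `ZeroRedundancyOptimizer`, each rank holds only a shard of the optimizer state, so `consolidate_state_dict(to=save_on_rank)` gathers the full state onto the saving rank before every other rank skips writing. The wiring below (model shape, directory, handler arguments) is illustrative, not taken from the PR.

```python
import torch.nn as nn
import torch.optim as optim
from torch.distributed.optim import ZeroRedundancyOptimizer

import ignite.distributed as idist
from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver

# Assumes a torch.distributed process group is already initialized
# (e.g. the script was launched via torchrun or idist.Parallel).
model = nn.Linear(16, 16).to(idist.device())
# Each rank keeps only its shard of the SGD state.
optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=optim.SGD, lr=0.01)

trainer = Engine(lambda engine, batch: None)  # placeholder training step

# Rank 0 receives the consolidated optimizer state and is the only rank
# that actually writes the checkpoint file.
handler = Checkpoint(
    {"model": model, "optimizer": optimizer},
    DiskSaver("/tmp/checkpoints", require_empty=False, save_on_rank=0),
    n_saved=2,
    save_on_rank=0,
)
trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)
```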
```diff
@@ -791,7 +795,7 @@ def __init__(
         atomic: bool = True,
         create_dir: bool = True,
         require_empty: bool = True,
-        save_on_rank: Optional[int] = 0,
+        save_on_rank: int = 0,
         **kwargs: Any,
     ):
         self.dirname = Path(dirname).expanduser()
```
```diff
@@ -957,7 +961,7 @@ def __init__(
         filename_pattern: Optional[str] = None,
         include_self: bool = False,
         greater_or_equal: bool = False,
-        save_on_rank: Optional[int] = 0,
+        save_on_rank: int = 0,
         **kwargs: Any,
     ):
 
```
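The same annotation change reaches `ModelCheckpoint`, which forwards `save_on_rank` to the underlying saver. A short hedged usage sketch under assumed names and paths (none of this appears in the PR itself):

```python
import torch.nn as nn

from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint

model = nn.Linear(16, 16)
trainer = Engine(lambda engine, batch: None)  # placeholder training step

# save_on_rank is a plain int (default 0) after this change; it selects which
# rank writes the checkpoint files in a distributed run.
checkpointer = ModelCheckpoint(
    "/tmp/models",
    filename_prefix="best",
    n_saved=2,
    require_empty=False,
    save_on_rank=0,
)
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {"model": model})
```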