Skip to content

Commit f239327

Browse files
Fix mypy error
1 parent f4c86a1 commit f239327

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed

ignite/handlers/checkpoint.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@
1616
if Version(torch.__version__) >= Version("1.9.0"):
1717
from torch.distributed.optim import ZeroRedundancyOptimizer
1818

19+
HAVE_ZERO = True
20+
else:
21+
HAVE_ZERO = False
22+
1923
import ignite.distributed as idist
2024
from ignite.base import Serializable
2125
from ignite.engine import Engine, Events
@@ -287,7 +291,7 @@ def __init__(
287291
filename_pattern: Optional[str] = None,
288292
include_self: bool = False,
289293
greater_or_equal: bool = False,
290-
save_on_rank: Optional[int] = 0,
294+
save_on_rank: int = 0,
291295
):
292296

293297
if not isinstance(to_save, collections.Mapping):
@@ -471,7 +475,7 @@ def _setup_checkpoint(self) -> Dict[str, Dict[Any, Any]]:
471475
for k, obj in self.to_save.items():
472476
if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
473477
obj = obj.module
474-
elif Version(torch.__version__) >= Version("1.9.0") and isinstance(obj, ZeroRedundancyOptimizer):
478+
elif HAVE_ZERO and isinstance(obj, ZeroRedundancyOptimizer):
475479
obj.consolidate_state_dict(to=self.save_on_rank)
476480
if self.save_on_rank != idist.get_rank():
477481
continue
@@ -791,7 +795,7 @@ def __init__(
791795
atomic: bool = True,
792796
create_dir: bool = True,
793797
require_empty: bool = True,
794-
save_on_rank: Optional[int] = 0,
798+
save_on_rank: int = 0,
795799
**kwargs: Any,
796800
):
797801
self.dirname = Path(dirname).expanduser()
@@ -957,7 +961,7 @@ def __init__(
957961
filename_pattern: Optional[str] = None,
958962
include_self: bool = False,
959963
greater_or_equal: bool = False,
960-
save_on_rank: Optional[int] = 0,
964+
save_on_rank: int = 0,
961965
**kwargs: Any,
962966
):
963967

tests/ignite/handlers/test_checkpoint.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,6 @@
1111
import torch.nn as nn
1212
from packaging.version import Version
1313

14-
if Version(torch.__version__) >= Version("1.9.0"):
15-
from torch.distributed.optim import ZeroRedundancyOptimizer
16-
1714
import ignite.distributed as idist
1815
from ignite.engine import Engine, Events, State
1916
from ignite.handlers import Checkpoint, DiskSaver, EarlyStopping, global_step_from_engine, ModelCheckpoint
@@ -1247,6 +1244,9 @@ def _test_checkpoint_load_objects_ddp(device):
12471244

12481245

12491246
def _test_checkpoint_with_ZeRO(device, dirname, local_rank):
1247+
1248+
from torch.distributed.optim import ZeroRedundancyOptimizer
1249+
12501250
model = DummyModel().to(device)
12511251
opt = ZeroRedundancyOptimizer(model.parameters(), torch.optim.SGD, lr=0.01)
12521252
mocked_opt = MagicMock(ZeroRedundancyOptimizer, wraps=opt)
@@ -1282,7 +1282,9 @@ def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo, dirname,
12821282
_test_checkpoint_with_ddp(device)
12831283
_test_checkpoint_load_objects_ddp(device)
12841284

1285-
if Version(torch.__version__) >= Version("1.9.0"):
1285+
from ignite.handlers.checkpoint import HAVE_ZERO
1286+
1287+
if HAVE_ZERO:
12861288
_test_checkpoint_with_ZeRO(device, dirname, local_rank)
12871289

12881290

0 commit comments

Comments
 (0)