Commit 5f64a06

sadra-barikbin authored and vfdev-5 committed

Feature add ZeRO support to Checkpoint in a distributed configuration (pytorch#2642)

* Implement feature
* Fix bug in docstring
* Fix bugs and tests
* Handle pytorch<1.9.0
* Fix mypy error

1 parent 1afabbf commit 5f64a06
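In end-user terms, the commit lets a Checkpoint handler save a torch.distributed.optim.ZeroRedundancyOptimizer whose state is sharded across ranks: before writing, the handler consolidates the shards onto the process given by save_on_rank, so the handler must run on every rank. A minimal usage sketch, not taken from this commit, assuming torch >= 1.9.0, a process group already initialized by a launcher such as torchrun, and illustrative names and paths (train_step, /tmp/ckpts):

# Sketch only: requires an initialized default process group; the module,
# the batch handling and the output directory below are placeholders.
import torch
import torch.nn as nn
from torch.distributed.optim import ZeroRedundancyOptimizer

from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver

model = nn.parallel.DistributedDataParallel(nn.Linear(10, 1))
optimizer = ZeroRedundancyOptimizer(model.parameters(), torch.optim.SGD, lr=0.01)

def train_step(engine, batch):
    optimizer.zero_grad()
    loss = model(batch).sum()
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(train_step)

# With this commit, Checkpoint consolidates the sharded optimizer state onto
# `save_on_rank` before saving, and DiskSaver writes the file on that rank.
to_save = {"model": model, "optimizer": optimizer}
checkpointer = Checkpoint(
    to_save, DiskSaver("/tmp/ckpts", require_empty=False), n_saved=2, save_on_rank=0
)
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer)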

File tree: 2 files changed (+54, -8 lines)

ignite/handlers/checkpoint.py (20 additions, 7 deletions)

@@ -11,6 +11,14 @@
 
 import torch
 import torch.nn as nn
+from packaging.version import Version
+
+if Version(torch.__version__) >= Version("1.9.0"):
+    from torch.distributed.optim import ZeroRedundancyOptimizer
+
+    HAVE_ZERO = True
+else:
+    HAVE_ZERO = False
 
 import ignite.distributed as idist
 from ignite.base import Serializable
@@ -166,13 +174,14 @@ class Checkpoint(Serializable):
         > checkpoint_12345.pt
 
     Note:
-        This class is distributed configuration-friendly: it is not required to instantiate the class in rank 0
-        process only.
+        This class is distributed configuration-friendly: it is not required to instantiate the class in rank 0 only
+        process. This class supports automatically distributed configuration and if used with
+        :class:`~ignite.handlers.DiskSaver`, checkpoint is stored by rank 0 process.
 
     .. warning::
 
-        When running on XLA devices, it should be run in all processes, otherwise application can get stuck on
-        saving the checkpoint.
+        When running on XLA devices or using :class:`~torch.distributed.optim.ZeroRedundancyOptimizer`, it
+        should be run in all processes, otherwise application can get stuck while saving the checkpoint.
 
     .. code-block:: python
 
@@ -282,7 +291,7 @@ def __init__(
         filename_pattern: Optional[str] = None,
         include_self: bool = False,
         greater_or_equal: bool = False,
-        save_on_rank: Optional[int] = 0,
+        save_on_rank: int = 0,
     ):
 
         if not isinstance(to_save, collections.Mapping):
@@ -466,6 +475,10 @@ def _setup_checkpoint(self) -> Dict[str, Dict[Any, Any]]:
         for k, obj in self.to_save.items():
             if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
                 obj = obj.module
+            elif HAVE_ZERO and isinstance(obj, ZeroRedundancyOptimizer):
+                obj.consolidate_state_dict(to=self.save_on_rank)
+                if self.save_on_rank != idist.get_rank():
+                    continue
             checkpoint[k] = obj.state_dict()
         return checkpoint
 
@@ -782,7 +795,7 @@ def __init__(
         atomic: bool = True,
         create_dir: bool = True,
         require_empty: bool = True,
-        save_on_rank: Optional[int] = 0,
+        save_on_rank: int = 0,
         **kwargs: Any,
     ):
         self.dirname = Path(dirname).expanduser()
@@ -948,7 +961,7 @@ def __init__(
         filename_pattern: Optional[str] = None,
         include_self: bool = False,
         greater_or_equal: bool = False,
-        save_on_rank: Optional[int] = 0,
+        save_on_rank: int = 0,
         **kwargs: Any,
     ):
 
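Restoring afterwards is unchanged by this commit; continuing the sketch above, it would look roughly like the following (the checkpoint filename is an assumption, and every rank can load the same consolidated file because ZeroRedundancyOptimizer.load_state_dict picks out the local rank's shard):

# Sketch: all ranks load the same consolidated checkpoint file (path assumed).
import torch

from ignite.handlers import Checkpoint

to_load = {"model": model, "optimizer": optimizer}  # mirrors `to_save` above
ckpt = torch.load("/tmp/ckpts/checkpoint_2.pt", map_location="cpu")
Checkpoint.load_objects(to_load=to_load, checkpoint=ckpt)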

tests/ignite/handlers/test_checkpoint.py (34 additions, 1 deletion)

@@ -1243,9 +1243,37 @@ def _test_checkpoint_load_objects_ddp(device):
     Checkpoint.load_objects(to_load, checkpoint)
 
 
+def _test_checkpoint_with_ZeRO(device, dirname, local_rank):
+
+    from torch.distributed.optim import ZeroRedundancyOptimizer
+
+    model = DummyModel().to(device)
+    opt = ZeroRedundancyOptimizer(model.parameters(), torch.optim.SGD, lr=0.01)
+    mocked_opt = MagicMock(ZeroRedundancyOptimizer, wraps=opt)
+
+    # A `step` should be called to optimizer state get populated.
+    out = model(torch.Tensor([1.0]))
+    out.backward()
+    mocked_opt.step()
+
+    to_save = {"model": model, "optim": mocked_opt}
+    checkpointer = Checkpoint(to_save, dirname, save_on_rank=1)
+
+    engine = Engine(lambda e, b: None)
+    checkpointer(engine)
+
+    mocked_opt.consolidate_state_dict.assert_called_once_with(to=1)
+
+    if local_rank == 1:
+
+        loaded_state_dict = torch.load(dirname / "checkpoint_0.pt", map_location=device)["optim"]
+        state_dict = opt.state_dict()
+        assert loaded_state_dict == state_dict
+
+
 @pytest.mark.distributed
 @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
-def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo, get_rank_zero_dirname):
+def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo, dirname, get_rank_zero_dirname, local_rank):
 
     device = idist.device()
     rank_zero_dirname = get_rank_zero_dirname()
@@ -1254,6 +1282,11 @@ def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo, get_rank_
     _test_checkpoint_with_ddp(device)
     _test_checkpoint_load_objects_ddp(device)
 
+    from ignite.handlers.checkpoint import HAVE_ZERO
+
+    if HAVE_ZERO:
+        _test_checkpoint_with_ZeRO(device, dirname, local_rank)
+
 
 @pytest.mark.distributed
 @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
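The new test relies on unittest.mock's spec-plus-wraps combination: MagicMock(ZeroRedundancyOptimizer, wraps=opt) records calls such as consolidate_state_dict for later assertions while forwarding them to the real optimizer. In isolation the pattern looks like this (a standalone illustration with made-up names, not ignite code):

# Standalone illustration of MagicMock(spec, wraps=real_obj): calls are
# recorded for assertions and also forwarded to the wrapped object.
from unittest.mock import MagicMock

class Accumulator:
    def __init__(self) -> None:
        self.total = 0

    def add(self, x: int) -> None:
        self.total += x

acc = Accumulator()
mocked = MagicMock(Accumulator, wraps=acc)

mocked.add(3)

mocked.add.assert_called_once_with(3)  # the call was recorded on the mock
assert acc.total == 3  # ...and delegated to the real Accumulator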
