
Commit 97a6186

carmocca and Borda authored
Sync module states during non-fit (Lightning-AI#17370)
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
1 parent 9becc15 commit 97a6186

6 files changed: +173, -193 lines changed


docs/source-pytorch/api_references.rst

Lines changed: 0 additions & 1 deletion
@@ -243,7 +243,6 @@ utilities
     combined_loader
     data
     deepspeed
-    distributed
     memory
     model_summary
     parsing

src/lightning/pytorch/overrides/distributed.py

Lines changed: 135 additions & 4 deletions
@@ -12,14 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import itertools
-from typing import Any, cast, Dict, Iterable, Iterator, List, Optional, Sized, Union
+from typing import Any, Callable, cast, Dict, Iterable, Iterator, List, Optional, Sized, Union
 
 import torch
 from torch import Tensor
-from torch.nn.parallel import DistributedDataParallel
+from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.utils.data import BatchSampler, DistributedSampler, Sampler
 
 from lightning.fabric.utilities.distributed import _DatasetSamplerWrapper
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12
+from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_info
 
 
 def _find_tensors(
@@ -37,7 +39,7 @@ def _find_tensors(
 
 # In manual_optimization, we need to call reducer prepare_for_backward.
 # Note: Keep track of PyTorch DDP and update if there is a change
-# https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/parallel/distributed.py#L626-L638
+# https://github.com/pytorch/pytorch/blob/v2.0.0/torch/nn/parallel/distributed.py#L1163-L1178
 def prepare_for_backward(model: DistributedDataParallel, output: Any) -> None:
     # `prepare_for_backward` is `DistributedDataParallel` specific.
     if torch.is_grad_enabled() and model.require_backward_grad_sync:
@@ -47,14 +49,143 @@ def prepare_for_backward(model: DistributedDataParallel, output: Any) -> None:
         # because we need to figure out which parameters were used during
         # this forward pass, to ensure we short circuit reduction for any
         # unused parameters. Only if `find_unused_parameters` is set.
-        args = list(_find_tensors(output)) if model.find_unused_parameters else []
+        args = list(_find_tensors(output)) if model.find_unused_parameters and not model.static_graph else []
         reducer = cast(torch._C._distributed_c10d.Reducer, model.reducer)
         reducer._rebuild_buckets()  # avoids "INTERNAL ASSERT FAILED" with `find_unused_parameters=False`
         reducer.prepare_for_backward(args)
     else:
         model.require_forward_param_sync = False
 
 
+def _register_ddp_comm_hook(
+    model: DistributedDataParallel,
+    ddp_comm_state: Optional[object] = None,
+    ddp_comm_hook: Optional[Callable] = None,
+    ddp_comm_wrapper: Optional[Callable] = None,
+) -> None:
+    """Function to register communication hook for DDP model https://pytorch.org/docs/master/ddp_comm_hooks.html.
+
+    Args:
+        model:
+            DDP model
+        ddp_comm_state:
+            state is passed to the hook and can be used to maintain
+            and update any state information that users would like to
+            maintain as part of the training process. Examples: error
+            feedback in gradient compression, peers to communicate with
+            next in GossipGrad etc.
+        ddp_comm_hook:
+            hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future
+
+            This callable function is called once the bucket is ready. The
+            hook can perform whatever processing is needed and return
+            a Future indicating completion of any async work (ex: allreduce).
+            If the hook doesn't perform any communication, it can also
+            just return a completed Future. The Future should hold the
+            new value of grad bucket's tensors. Once a bucket is ready,
+            c10d reducer would call this hook and use the tensors returned
+            by the Future and copy grads to individual parameters.
+        ddp_comm_wrapper:
+            communication hook wrapper to support a communication hook such
+            as FP16 compression as wrapper, which could be combined with
+            ddp_comm_hook
+
+    Examples:
+
+        >>> from torch.distributed.algorithms.ddp_comm_hooks import ( # doctest: +SKIP
+        ...     default_hooks as default,
+        ...     powerSGD_hook as powerSGD,
+        ...     post_localSGD_hook as post_localSGD,
+        ... )
+        >>> # fp16_compress_hook for compress gradients
+        >>> ddp_model = ...
+        >>> _register_ddp_comm_hook( # doctest: +SKIP
+        ...     model=ddp_model,
+        ...     ddp_comm_hook=default.fp16_compress_hook,
+        ... )
+        >>> # powerSGD_hook
+        >>> ddp_model = ...
+        >>> _register_ddp_comm_hook( # doctest: +SKIP
+        ...     model=ddp_model,
+        ...     ddp_comm_state=powerSGD.PowerSGDState(
+        ...         process_group=None,
+        ...         matrix_approximation_rank=1,
+        ...         start_powerSGD_iter=5000,
+        ...     ),
+        ...     ddp_comm_hook=powerSGD.powerSGD_hook,
+        ... )
+        >>> # post_localSGD_hook
+        >>> subgroup, _ = torch.distributed.new_subgroups() # doctest: +SKIP
+        >>> ddp_model = ...
+        >>> _register_ddp_comm_hook( # doctest: +SKIP
+        ...     model=ddp_model,
+        ...     state=post_localSGD.PostLocalSGDState(
+        ...         process_group=None,
+        ...         subgroup=subgroup,
+        ...         start_localSGD_iter=1_000,
+        ...     ),
+        ...     ddp_comm_hook=post_localSGD.post_localSGD_hook,
+        ... )
+        >>> # fp16_compress_wrapper combined with other communication hook
+        >>> ddp_model = ...
+        >>> _register_ddp_comm_hook( # doctest: +SKIP
+        ...     model=ddp_model,
+        ...     ddp_comm_state=powerSGD.PowerSGDState(
+        ...         process_group=None,
+        ...         matrix_approximation_rank=1,
+        ...         start_powerSGD_iter=5000,
+        ...     ),
+        ...     ddp_comm_hook=powerSGD.powerSGD_hook,
+        ...     ddp_comm_wrapper=default.fp16_compress_wrapper,
+        ... )
+    """
+    if ddp_comm_hook is None:
+        return
+    # inform mypy that ddp_comm_hook is callable
+    ddp_comm_hook: Callable = ddp_comm_hook
+
+    if ddp_comm_wrapper is not None:
+        rank_zero_info(
+            f"DDP comm wrapper is provided, apply {ddp_comm_wrapper.__qualname__}({ddp_comm_hook.__qualname__})."
+        )
+        ddp_comm_hook = ddp_comm_wrapper(ddp_comm_hook)
+
+    rank_zero_debug(f"Registering DDP comm hook: {ddp_comm_hook.__qualname__}.")
+    model.register_comm_hook(state=ddp_comm_state, hook=ddp_comm_hook)
+
+
+def _sync_module_states(module: torch.nn.Module) -> None:
+    """Taken from https://github.com/pytorch/pytorch/blob/v2.0.0/torch/nn/parallel/distributed.py#L675-L682."""
+    parameters_to_ignore = (
+        set(module._ddp_params_and_buffers_to_ignore)  # type: ignore[arg-type]
+        if hasattr(module, "_ddp_params_and_buffers_to_ignore")
+        else set()
+    )
+    from torch.distributed.distributed_c10d import _get_default_group
+
+    if not _TORCH_GREATER_EQUAL_1_12:
+        module_states = []
+        for name, param in module.named_parameters():
+            if name not in parameters_to_ignore:
+                module_states.append(param.detach())
+        for name, buffer in module.named_buffers():
+            if name not in parameters_to_ignore:
+                module_states.append(buffer.detach())
+        if len(module_states) > 0:
+            torch.distributed._broadcast_coalesced(_get_default_group(), module_states, 250 * 1024 * 1024, 0)
+        return
+
+    from torch.distributed.utils import _sync_module_states as torch_sync_module_states
+
+    torch_sync_module_states(
+        module,
+        _get_default_group(),
+        250 * 1024 * 1024,
+        src=0,
+        params_and_buffers_to_ignore=parameters_to_ignore,
+    )
+
+
 class UnrepeatedDistributedSampler(DistributedSampler):
     """A fork of the PyTorch DistributedSampler that doesn't repeat data, instead allowing the number of batches
     per process to be off-by-one from each other. This makes this sampler usable for predictions (it's
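For context, here is a minimal standalone sketch (not part of the commit) of what the new `_sync_module_states` helper is for: broadcasting rank 0's parameters and buffers to every other rank over the default process group. The two-process `gloo` setup, the fixed port, and the toy `Linear` module are assumptions made purely for illustration, and the import path assumes a Lightning build that already includes this change.

# demo_sync_module_states.py -- hypothetical example, not shipped with Lightning
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"  # arbitrary free port, chosen for this demo
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # every rank initializes its own, different weights on purpose
    torch.manual_seed(rank)
    model = torch.nn.Linear(4, 2)

    # helper added in this commit: broadcasts rank 0's parameters/buffers to all ranks
    from lightning.pytorch.overrides.distributed import _sync_module_states

    _sync_module_states(model)

    # verify: gather the weight from every rank and compare against rank 0's copy
    gathered = [torch.empty_like(model.weight) for _ in range(world_size)]
    dist.all_gather(gathered, model.weight.detach())
    assert all(torch.equal(w, gathered[0]) for w in gathered)

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)

If the assertion passes on both processes, every rank ends up holding rank 0's weights even though no DistributedDataParallel wrapper was ever created.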

src/lightning/pytorch/strategies/ddp.py

Lines changed: 9 additions & 7 deletions
@@ -38,13 +38,12 @@
 from lightning.fabric.utilities.types import ReduceOp
 from lightning.pytorch.core.optimizer import LightningOptimizer
 from lightning.pytorch.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
-from lightning.pytorch.overrides.distributed import prepare_for_backward
+from lightning.pytorch.overrides.distributed import _register_ddp_comm_hook, _sync_module_states, prepare_for_backward
 from lightning.pytorch.plugins.precision import PrecisionPlugin
 from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher, _SubprocessScriptLauncher
 from lightning.pytorch.strategies.parallel import ParallelStrategy
 from lightning.pytorch.strategies.strategy import TBroadcast
 from lightning.pytorch.trainer.states import TrainerFn
-from lightning.pytorch.utilities.distributed import _register_ddp_comm_hook
 from lightning.pytorch.utilities.exceptions import _augment_message
 from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_only
 from lightning.pytorch.utilities.types import PredictStep, STEP_OUTPUT, TestStep, ValidationStep
@@ -153,25 +152,28 @@ def setup(self, trainer: "pl.Trainer") -> None:
         # skip wrapping the model if we are not fitting as no gradients need to be exchanged
         trainer_fn = trainer.state.fn
 
-        if trainer_fn == TrainerFn.FITTING:
-            if self._layer_sync:
-                assert self.model is not None
-                self.model = self._layer_sync.apply(self.model)
+        if trainer_fn == TrainerFn.FITTING and self._layer_sync:
+            assert self.model is not None
+            self.model = self._layer_sync.apply(self.model)
 
         self.setup_precision_plugin()
 
         if trainer_fn == TrainerFn.FITTING:
+            # do not wrap with DDP if not fitting as there's no gradients to reduce
             self.configure_ddp()
 
             # set up optimizers after the wrapped module has been moved to the device
             self.setup_optimizers(trainer)
             _optimizers_to_device(self.optimizers, self.root_device)
 
-        if trainer_fn == TrainerFn.FITTING:
             import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
 
             if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState):
                 self._enable_model_averaging()
+        else:
+            # we need to manually synchronize the module's states since we aren't using the DDP wrapper
+            assert self.model is not None
+            _sync_module_states(self.model)
 
     def _setup_model(self, model: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""

src/lightning/pytorch/utilities/distributed.py

Lines changed: 0 additions & 144 deletions
This file was deleted.

0 commit comments
