# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
- from typing import Any, cast, Dict, Iterable, Iterator, List, Optional, Sized, Union
+ from typing import Any, Callable, cast, Dict, Iterable, Iterator, List, Optional, Sized, Union

import torch
from torch import Tensor
- from torch.nn.parallel import DistributedDataParallel
+ from torch.nn.parallel.distributed import DistributedDataParallel
from torch.utils.data import BatchSampler, DistributedSampler, Sampler

from lightning.fabric.utilities.distributed import _DatasetSamplerWrapper
+ from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12
+ from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_info


def _find_tensors(
@@ -37,7 +39,7 @@ def _find_tensors(

# In manual_optimization, we need to call reducer prepare_for_backward.
# Note: Keep track of PyTorch DDP and update if there is a change
- # https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/parallel/distributed.py#L626-L638
+ # https://github.com/pytorch/pytorch/blob/v2.0.0/torch/nn/parallel/distributed.py#L1163-L1178
def prepare_for_backward(model: DistributedDataParallel, output: Any) -> None:
    # `prepare_for_backward` is `DistributedDataParallel` specific.
    if torch.is_grad_enabled() and model.require_backward_grad_sync:
@@ -47,14 +49,143 @@ def prepare_for_backward(model: DistributedDataParallel, output: Any) -> None:
        # because we need to figure out which parameters were used during
        # this forward pass, to ensure we short circuit reduction for any
        # unused parameters. Only if `find_unused_parameters` is set.
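+         # NOTE: mirrors DDP's own forward pass, where the search over the outputs is
+         # skipped for `static_graph=True` because the set of used parameters is assumed
+         # to be fixed after the first iteration.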
-         args = list(_find_tensors(output)) if model.find_unused_parameters else []
+         args = list(_find_tensors(output)) if model.find_unused_parameters and not model.static_graph else []
        reducer = cast(torch._C._distributed_c10d.Reducer, model.reducer)
        reducer._rebuild_buckets()  # avoids "INTERNAL ASSERT FAILED" with `find_unused_parameters=False`
        reducer.prepare_for_backward(args)
    else:
        model.require_forward_param_sync = False


+ def _register_ddp_comm_hook(
+     model: DistributedDataParallel,
+     ddp_comm_state: Optional[object] = None,
+     ddp_comm_hook: Optional[Callable] = None,
+     ddp_comm_wrapper: Optional[Callable] = None,
+ ) -> None:
+     """Register a communication hook for a DDP model: https://pytorch.org/docs/master/ddp_comm_hooks.html.
+
+     Args:
+         model:
+             DDP model
+         ddp_comm_state:
+             state is passed to the hook and can be used to maintain
+             and update any state information that users would like to
+             maintain as part of the training process. Examples: error
+             feedback in gradient compression, peers to communicate with
+             next in GossipGrad, etc.
+         ddp_comm_hook:
+             hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future
+
+             This callable function is called once the bucket is ready. The
+             hook can perform whatever processing is needed and return
+             a Future indicating completion of any async work (ex: allreduce).
+             If the hook doesn't perform any communication, it can also
+             just return a completed Future. The Future should hold the
+             new value of grad bucket's tensors. Once a bucket is ready,
+             the c10d reducer calls this hook, uses the tensors returned
+             by the Future, and copies the grads to the individual parameters.
+         ddp_comm_wrapper:
+             communication hook wrapper that wraps ``ddp_comm_hook``, e.g. to
+             apply FP16 compression on top of another communication hook
+
+     Examples:
+
+         >>> from torch.distributed.algorithms.ddp_comm_hooks import ( # doctest: +SKIP
+         ...     default_hooks as default,
+         ...     powerSGD_hook as powerSGD,
+         ...     post_localSGD_hook as post_localSGD,
+         ... )
+         >>> # fp16_compress_hook for compressing gradients
+         >>> ddp_model = ...
+         >>> _register_ddp_comm_hook( # doctest: +SKIP
+         ...     model=ddp_model,
+         ...     ddp_comm_hook=default.fp16_compress_hook,
+         ... )
+         >>> # powerSGD_hook
+         >>> ddp_model = ...
+         >>> _register_ddp_comm_hook( # doctest: +SKIP
+         ...     model=ddp_model,
+         ...     ddp_comm_state=powerSGD.PowerSGDState(
+         ...         process_group=None,
+         ...         matrix_approximation_rank=1,
+         ...         start_powerSGD_iter=5000,
+         ...     ),
+         ...     ddp_comm_hook=powerSGD.powerSGD_hook,
+         ... )
+         >>> # post_localSGD_hook
+         >>> subgroup, _ = torch.distributed.new_subgroups() # doctest: +SKIP
+         >>> ddp_model = ...
+         >>> _register_ddp_comm_hook( # doctest: +SKIP
+         ...     model=ddp_model,
+         ...     ddp_comm_state=post_localSGD.PostLocalSGDState(
+         ...         process_group=None,
+         ...         subgroup=subgroup,
+         ...         start_localSGD_iter=1_000,
+         ...     ),
+         ...     ddp_comm_hook=post_localSGD.post_localSGD_hook,
+         ... )
+         >>> # fp16_compress_wrapper combined with another communication hook
+         >>> ddp_model = ...
+         >>> _register_ddp_comm_hook( # doctest: +SKIP
+         ...     model=ddp_model,
+         ...     ddp_comm_state=powerSGD.PowerSGDState(
+         ...         process_group=None,
+         ...         matrix_approximation_rank=1,
+         ...         start_powerSGD_iter=5000,
+         ...     ),
+         ...     ddp_comm_hook=powerSGD.powerSGD_hook,
+         ...     ddp_comm_wrapper=default.fp16_compress_wrapper,
+         ... )
+     """
+     if ddp_comm_hook is None:
+         return
+     # inform mypy that ddp_comm_hook is callable
+     ddp_comm_hook: Callable = ddp_comm_hook
+
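+     # NOTE: a comm wrapper decorates the hook's bucket processing; for example,
+     # `default.fp16_compress_wrapper` casts the bucket gradients to FP16 before the
+     # wrapped hook runs and casts the result back afterwards.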
+     if ddp_comm_wrapper is not None:
+         rank_zero_info(
+             f"DDP comm wrapper is provided, apply {ddp_comm_wrapper.__qualname__}({ddp_comm_hook.__qualname__})."
+         )
+         ddp_comm_hook = ddp_comm_wrapper(ddp_comm_hook)
+
+     rank_zero_debug(f"Registering DDP comm hook: {ddp_comm_hook.__qualname__}.")
+     model.register_comm_hook(state=ddp_comm_state, hook=ddp_comm_hook)
+
+
+ def _sync_module_states(module: torch.nn.Module) -> None:
+     """Taken from https://github.com/pytorch/pytorch/blob/v2.0.0/torch/nn/parallel/distributed.py#L675-L682."""
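+     # NOTE: `_ddp_params_and_buffers_to_ignore` is the attribute DDP itself consults
+     # (set via `DistributedDataParallel._set_params_and_buffers_to_ignore_for_model`),
+     # so the same exclusions apply to this manual state sync.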
+     parameters_to_ignore = (
+         set(module._ddp_params_and_buffers_to_ignore)  # type: ignore[arg-type]
+         if hasattr(module, "_ddp_params_and_buffers_to_ignore")
+         else set()
+     )
+     from torch.distributed.distributed_c10d import _get_default_group
+
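+     # NOTE: both branches broadcast from rank 0 in 250 MiB buckets (DDP's default
+     # `broadcast_bucket_size`); for torch < 1.12 the parameters and buffers are collected
+     # and broadcast manually, newer versions delegate to `torch.distributed.utils._sync_module_states`.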
+     if not _TORCH_GREATER_EQUAL_1_12:
+         module_states = []
+         for name, param in module.named_parameters():
+             if name not in parameters_to_ignore:
+                 module_states.append(param.detach())
+         for name, buffer in module.named_buffers():
+             if name not in parameters_to_ignore:
+                 module_states.append(buffer.detach())
+         if len(module_states) > 0:
+             torch.distributed._broadcast_coalesced(_get_default_group(), module_states, 250 * 1024 * 1024, 0)
+         return
+
+     from torch.distributed.utils import _sync_module_states as torch_sync_module_states
+
+     torch_sync_module_states(
+         module,
+         _get_default_group(),
+         250 * 1024 * 1024,
+         src=0,
+         params_and_buffers_to_ignore=parameters_to_ignore,
+     )
+
+
class UnrepeatedDistributedSampler(DistributedSampler):
    """A fork of the PyTorch DistributedSampler that doesn't repeat data, instead allowing the number of batches
    per process to be off-by-one from each other. This makes this sampler usable for predictions (it's