@@ -213,8 +213,14 @@ def __init__(
         self.should_quantize = should_quantize
 
         self._grads: Dict[str, torch.Tensor] = {}
+
+        # Used to save global parameters so that they can be restored in case
+        # commit fails
         self.original_parameters: Dict[str, torch.Tensor] = {}
 
+        # Used to mix the local and global parameters
+        self._local_parameters: Dict[str, torch.Tensor] = {}
+
         for name, p in self._model_fragment.named_parameters():
             if isinstance(p, DTensor):
                 p = extract_local_tensor(p.data)
@@ -240,21 +246,22 @@ def save_parameters(self) -> None:
     @torch.profiler.record_function("torchft::local_sgd::restore_parameters")
     def restore_parameters(self) -> None:
         with torch.no_grad():
+            assert len(self._local_parameters) == 0
             # TODO: consider running copy on a separate stream
             for name, p in self._model_fragment.named_parameters():
+                self._local_parameters[name] = p.data
+
                 if isinstance(p, DTensor):
                     # we averaged the local version of the tensor so need to copy it back as a DTensor
-                    p.data.copy_(
-                        DTensor.from_local(
-                            self.original_parameters[name],
-                            p.device_mesh,
-                            p.placements,
-                            shape=p.shape,
-                            stride=p.stride(),
-                        ),
-                        non_blocking=False,
+                    p.data = DTensor.from_local(
+                        self.original_parameters[name],
+                        p.device_mesh,
+                        p.placements,
+                        shape=p.shape,
+                        stride=p.stride(),
                     )
                 else:
+                    p.data = torch.empty_like(self.original_parameters[name])
                     p.data.copy_(self.original_parameters[name], non_blocking=False)
 
     def _set_grads(self) -> None:
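
Note on the restore_parameters change above: the current (local) tensor for each parameter is stashed in self._local_parameters before the global values are restored, so the restore can no longer write through p.data in place. An in-place copy_ would overwrite the stashed local values via the shared storage; rebinding p.data to a fresh tensor keeps the stash intact. A minimal sketch of the aliasing issue with plain tensors (names are illustrative, not the torchft API):

    import torch

    p = torch.nn.Parameter(torch.tensor([1.0, 2.0]))  # stands in for a local parameter
    saved = p.data                                     # like self._local_parameters[name]
    global_vals = torch.tensor([10.0, 20.0])           # like self.original_parameters[name]

    # In-place copy writes through the shared storage, so the stash is clobbered:
    p.data.copy_(global_vals)
    print(saved)  # tensor([10., 20.]) -- the local values are gone

    # Rebinding p.data to a new tensor leaves the stashed local values untouched:
    p = torch.nn.Parameter(torch.tensor([1.0, 2.0]))
    saved = p.data
    p.data = torch.empty_like(global_vals)
    p.data.copy_(global_vals)
    print(saved)  # tensor([1., 2.]) -- the stash still holds the local values
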
@@ -269,6 +276,18 @@ def _set_grads(self) -> None:
 
                 del self._grads[name]
 
+    def _merge_parameters(self) -> None:
+        """
+        Merges the local and global parameters.
+        """
+        for name, p in self._model_fragment.named_parameters():
+            p.data.lerp_(
+                self._local_parameters[name], 1 - self._fragment_update_alpha
+            )
+
+        # we don't need the local parameters anymore
+        self._local_parameters = {}
+
     @torch.profiler.record_function("torchft::local_sgd::wait")
     def wait(self) -> None:
         """
@@ -313,6 +332,8 @@ def prepare_sync(self) -> None:
             else:
                 self._grads[name] = pseudogradient
 
+        assert len(self._allreduce_futures) == 0
+
         # Make sure tensors are available to `_stream`
         if self._stream is not None:
             self._stream.wait_stream(torch.cuda.current_stream())
@@ -352,6 +373,7 @@ def perform_sync(self) -> bool:
             self._set_grads()
             self._outer_optimizer.step()
             self.save_parameters()
+            self._merge_parameters()
             self._outer_optimizer.zero_grad()
 
         return should_commit
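
Per the ordering in this hunk, save_parameters() runs before _merge_parameters(), so self.original_parameters snapshots the un-blended global values (the rollback copy used if a later commit fails) while p.data ends up holding the blend that training continues from. A minimal sketch of that ordering with plain tensors (illustrative names, not the torchft API):

    import torch

    alpha = 0.5
    local_backup = torch.tensor([8.0])   # stashed by restore_parameters
    p = torch.tensor([0.0])              # new global value after the outer step

    original = p.clone()                 # save_parameters: rollback snapshot
    p.lerp_(local_backup, 1 - alpha)     # _merge_parameters: blend toward local

    print(original)                      # tensor([0.]) -- still the pure global value
    print(p)                             # tensor([4.]) -- blended parameters used for the next local steps
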
@@ -512,12 +534,6 @@ def __init__(
         if fragment_update_alpha < 0 or fragment_update_alpha > 1:
             raise ValueError("fragment_update_alpha must be between 0 and 1")
 
-        # TODO: Support `fragment_update_alpha`
-        if fragment_update_alpha != 0.0:
-            raise ValueError(
-                "Merging local parameters with global parameters is not supported yet"
-            )
-
         super().__init__()
         self._manager = manager