
Commit b93142d

enable merging parameters for diloco

1 parent 4b991c8

File tree

1 file changed: +31 -15 lines changed

torchft/local_sgd.py

Lines changed: 31 additions & 15 deletions
@@ -213,8 +213,14 @@ def __init__(
         self.should_quantize = should_quantize
 
         self._grads: Dict[str, torch.Tensor] = {}
+
+        # Used to save global parameters so that they can be restored in case
+        # commit fails
         self.original_parameters: Dict[str, torch.Tensor] = {}
 
+        # Used to mix the local and global parameters
+        self._local_parameters: Dict[str, torch.Tensor] = {}
+
         for name, p in self._model_fragment.named_parameters():
             if isinstance(p, DTensor):
                 p = extract_local_tensor(p.data)
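
For context, the two dictionaries added here play different roles across a sync: self.original_parameters snapshots the global (averaged) weights after each committed sync, while self._local_parameters stashes the local weights only for the duration of a sync so _merge_parameters can blend them back in. A minimal standalone sketch of that lifecycle, with illustrative names that are not part of torchft:

    import torch

    original = {}   # global snapshot, survives across syncs (save_parameters)
    local = {}      # local weights, held only during a sync (restore_parameters)

    w = torch.zeros(3)
    original["w"] = w.clone()      # snapshot the synced/global weights

    w += 0.5                       # inner-loop steps drift the local weights

    local["w"] = w                 # stash the drifted local tensor
    w = original["w"].clone()      # continue the outer step from the global snapshot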
@@ -240,21 +246,22 @@ def save_parameters(self) -> None:
     @torch.profiler.record_function("torchft::local_sgd::restore_parameters")
     def restore_parameters(self) -> None:
         with torch.no_grad():
+            assert len(self._local_parameters) == 0
             # TODO: consider running copy on a separate stream
             for name, p in self._model_fragment.named_parameters():
+                self._local_parameters[name] = p.data
+
                 if isinstance(p, DTensor):
                     # we averaged the local version of the tensor so need to copy it back as a DTensor
-                    p.data.copy_(
-                        DTensor.from_local(
-                            self.original_parameters[name],
-                            p.device_mesh,
-                            p.placements,
-                            shape=p.shape,
-                            stride=p.stride(),
-                        ),
-                        non_blocking=False,
+                    p.data = DTensor.from_local(
+                        self.original_parameters[name],
+                        p.device_mesh,
+                        p.placements,
+                        shape=p.shape,
+                        stride=p.stride(),
                     )
                 else:
+                    p.data = torch.empty_like(self.original_parameters[name])
                     p.data.copy_(self.original_parameters[name], non_blocking=False)
 
     def _set_grads(self) -> None:
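
The switch from p.data.copy_(...) to rebinding p.data is what makes the stash above safe: self._local_parameters[name] aliases the parameter's old storage, so an in-place copy would overwrite the very local weights it is meant to preserve, while rebinding points p.data at fresh storage and leaves the alias intact. A small sketch of the difference on plain tensors (not torchft code):

    import torch

    # In-place copy: the alias observes the overwrite.
    p = torch.nn.Parameter(torch.ones(2))
    stashed = p.data                       # alias, like _local_parameters[name]
    p.data.copy_(torch.zeros(2))
    assert torch.equal(stashed, torch.zeros(2))   # local values clobbered

    # Rebinding: the alias keeps the old values.
    p = torch.nn.Parameter(torch.ones(2))
    stashed = p.data
    p.data = torch.zeros(2)                # new storage for the parameter
    assert torch.equal(stashed, torch.ones(2))    # local values preserved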
@@ -269,6 +276,18 @@ def _set_grads(self) -> None:
 
                 del self._grads[name]
 
+    def _merge_parameters(self) -> None:
+        """
+        Merges the local and global parameters.
+        """
+        for name, p in self._model_fragment.named_parameters():
+            p.data.lerp_(
+                self._local_parameters[name], 1 - self._fragment_update_alpha
+            )
+
+        # we don't need the local parameters anymore
+        self._local_parameters = {}
+
     @torch.profiler.record_function("torchft::local_sgd::wait")
     def wait(self) -> None:
         """
@@ -313,6 +332,8 @@ def prepare_sync(self) -> None:
             else:
                 self._grads[name] = pseudogradient
 
+        assert len(self._allreduce_futures) == 0
+
         # Make sure tensors are available to `_stream`
         if self._stream is not None:
             self._stream.wait_stream(torch.cuda.current_stream())
@@ -352,6 +373,7 @@ def perform_sync(self) -> bool:
             self._set_grads()
             self._outer_optimizer.step()
             self.save_parameters()
+            self._merge_parameters()
             self._outer_optimizer.zero_grad()
 
         return should_commit
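
The ordering here matters: save_parameters() must run before _merge_parameters() so the snapshot keeps the pure global weights while p.data becomes the blend that local training resumes from. A toy single-parameter walk-through of one committed sync (illustrative values and a plain SGD outer step; not the torchft implementation):

    import torch

    alpha = 0.5
    global_w = torch.tensor([0.0])             # last synced weights
    local_w = global_w + torch.tensor([4.0])   # weights after local inner steps

    pseudograd = global_w - local_w            # tensor([-4.])
    new_global = global_w - 0.5 * pseudograd   # outer SGD, lr=0.5 -> tensor([2.])

    snapshot = new_global.clone()              # save_parameters(): pure global weights
    new_global.lerp_(local_w, 1 - alpha)       # _merge_parameters() -> tensor([3.])
    # training resumes from the blend; snapshot can restore the global weights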
@@ -512,12 +534,6 @@ def __init__(
         if fragment_update_alpha < 0 or fragment_update_alpha > 1:
             raise ValueError("fragment_update_alpha must be between 0 and 1")
 
-        # TODO: Support `fragment_update_alpha`
-        if fragment_update_alpha != 0.0:
-            raise ValueError(
-                "Merging local parameters with global parameters is not supported yet"
-            )
-
         super().__init__()
         self._manager = manager
