Skip to content

Commit d1c0fc8

Browse files
committed
enable merging parameters for diloco
1 parent 1f93550 commit d1c0fc8

File tree

1 file changed

+35
-6
lines changed

1 file changed

+35
-6
lines changed

torchft/local_sgd.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,14 @@ def __init__(
213213
self.should_quantize = should_quantize
214214

215215
self._grads: Dict[str, torch.Tensor] = {}
216+
217+
# Used to save global parameters so that they can be restored in case
218+
# commit fails
216219
self.original_parameters: Dict[str, torch.Tensor] = {}
217220

221+
# Used to mix the local and global parameters
222+
self._local_parameters: Dict[str, torch.Tensor] = {}
223+
218224
for name, p in self._model_fragment.named_parameters():
219225
if isinstance(p, DTensor):
220226
p = extract_local_tensor(p.data)
@@ -237,6 +243,14 @@ def save_parameters(self) -> None:
237243
param_to_local = extract_local_tensor(p.data)
238244
self.original_parameters[name].copy_(param_to_local, non_blocking=True)
239245

246+
def _save_local_parameters(self) -> None:
247+
"""
248+
Saves a copy of the model's parameters.
249+
"""
250+
with torch.no_grad():
251+
for name, p in self._model_fragment.named_parameters():
252+
self._local_parameters[name] = extract_local_tensor(p.data)
253+
240254
@torch.profiler.record_function("torchft::local_sgd::restore_parameters")
241255
def restore_parameters(self) -> None:
242256
with torch.no_grad():
@@ -284,6 +298,21 @@ def _set_grads(self) -> None:
284298
else:
285299
p.grad = self._grads[name]
286300

301+
def _clear_local_parameters(self) -> None:
302+
"""
303+
Clears the saved copy of the model's parameters
304+
"""
305+
self._local_parameters = {}
306+
307+
def _merge_parameters(self) -> None:
308+
"""
309+
Merges the local and global parameters.
310+
"""
311+
for name, p in self._model_fragment.named_parameters():
312+
torch.lerp(
313+
p.data, self._local_parameters[name], 1 - self._fragment_update_alpha
314+
)
315+
287316
@torch.profiler.record_function("torchft::local_sgd::wait")
288317
def wait(self) -> None:
289318
"""
@@ -352,6 +381,8 @@ def perform_sync(self) -> bool:
352381

353382
self.wait()
354383

384+
# save the parameters so they can be used for merging
385+
self._save_local_parameters()
355386
# Restore the parameters back to the previous state
356387
self.restore_parameters()
357388

@@ -362,8 +393,12 @@ def perform_sync(self) -> bool:
362393
self._set_grads()
363394
self._outer_optimizer.step()
364395
self.save_parameters()
396+
self._merge_parameters()
365397
self._outer_optimizer.zero_grad()
366398

399+
# free up memory
400+
self._clear_local_parameters()
401+
367402
return should_commit
368403

369404
def _average_grads(self) -> None:
@@ -515,12 +550,6 @@ def __init__(
515550
if fragment_update_alpha < 0 or fragment_update_alpha > 1:
516551
raise ValueError("fragment_update_alpha must be between 0 and 1")
517552

518-
# TODO: Support `fragment_update_alpha`
519-
if fragment_update_alpha != 0.0:
520-
raise ValueError(
521-
"Merging local parameters with global parameters is not supported yet"
522-
)
523-
524553
super().__init__()
525554
self._manager = manager
526555

0 commit comments

Comments
 (0)