Generalize DiLoCo to support Streaming #205
@@ -9,6 +9,7 @@
This module implements a fault tolerant version of LocalSGD and related methods.
"""
import logging
import math
from types import TracebackType
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type

@@ -158,70 +159,44 @@ def _average(self) -> list[torch.Tensor]:
        return averaged_parameters


class DiLoCo:
    """
    DiLoCo is a subclass of LocalSGD that overrides the synchronization
    mechanism to average and synchronize the pseudogradients (delta of the previous global weight and current local weights).

    This algorithm requires a backup copy of the
    weights. By default these are stored in CPU memory. If any error occurs
    during the DiLoCo step, the step will be discarded and the model
    parameters will reset back to the last time DiLoCo synchronized.

    DiLoCo paper: https://arxiv.org/pdf/2311.08105
    """

class _StreamingDiLoCoFragment:
    bucket_cap_mb: int = 32 * 1024 * 1024
    use_bucketization: bool = False

    def __init__(
        self,
        manager: Manager,
        model: nn.Module,
        model_fragment: nn.Module,
        fragment_sync_offset: int,
        inner_optimizer: optim.Optimizer,
        outer_optimizer: optim.Optimizer,
        sync_every: int,
        backup_device: Optional[torch.device] = None,
        pin_memory: bool = True,
        use_bucketization: bool = False,
        bucket_cap_mb: Optional[int] = None,
        should_quantize: bool = False,
    ) -> None:
        """
        Args:
            manager: The manager to use.
            model: The model to wrap.
            inner_optimizer: The optimizer used for the local parameters every step.
            outer_optimizer: The optimizer used for the global parameters updated every "sync_every" steps.
            sync_every: How often to update the model weights.
            backup_device: The device to store the backup weights on. If None, the backup weights will be on CPU.
            pin_memory: Whether to pin the memory for the backup weights (only for CPU device).
        """

        if manager._use_async_quorum:
            raise ValueError(
                "Using DiLoCo requires synchronous quorum to be enabled. "
                "Ensure that the manager is initialized with use_async_quorum=False"
            )
        super().__init__()
        self._manager = manager
        self._model = model
        self._model_fragment = model_fragment
        self._fragment_sync_offset = fragment_sync_offset
        self._local_optimizer = inner_optimizer
        self._local_step = 0
        self._sync_every = sync_every
        assert sync_every >= 1, "sync_every must be greater than or equal to 1"
        self._backup_device = backup_device
        self._pin_memory = pin_memory

        self._hooks: List[RemovableHandle] = []
        self._outer_optimizer = outer_optimizer

        if bucket_cap_mb is not None:
            self.bucket_cap_mb = int(bucket_cap_mb * 1024 * 1024)

        self.use_bucketization = use_bucketization
        self.should_quantize = should_quantize

        self.original_parameters: Dict[str, torch.Tensor] = {}
        for name, p in self._model.named_parameters():
        for name, p in self._model_fragment.named_parameters():
            if isinstance(p, DTensor):
                p = extract_local_tensor(p.data)

@@ -235,20 +210,17 @@ def __init__(
                t = t.pin_memory()
            self.original_parameters[name] = t

        # Need to copy the parameters to the host to be safe if we are on the first step.
        self._save_parameters()

    def _save_parameters(self) -> None:
    def save_parameters(self) -> None:
        with torch.no_grad():
            # TODO: consider running copy on a separate stream
            for name, p in self._model.named_parameters():
            for name, p in self._model_fragment.named_parameters():
                param_to_local = extract_local_tensor(p.data)
                self.original_parameters[name].copy_(param_to_local, non_blocking=True)

    def _restore_parameters(self) -> None:
    def restore_parameters(self) -> None:
        with torch.no_grad():
            # TODO: consider running copy on a separate stream
            for name, p in self._model.named_parameters():
            for name, p in self._model_fragment.named_parameters():
                if isinstance(p, DTensor):
                    # we averaged the local version of the tensor so need to copy it back as a DTensor
                    p.data.copy_(

@@ -264,52 +236,25 @@ def _restore_parameters(self) -> None:
                else:
                    p.data.copy_(self.original_parameters[name], non_blocking=False)

    def __enter__(self) -> "DiLoCo":
        # Add optimizer hook which increments the local step counter and syncs if necessary
        self._hooks.append(
            self._local_optimizer.register_step_post_hook(self._step_post_hook)
        )
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> bool:
        # Handle any cleanup or error handling here
        # Clean up hooks
        for hook in self._hooks:
            hook.remove()
        self._hooks.clear()

        return False  # Propagate exceptions

    def _step_post_hook(
        self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
    ) -> None:
        """
        This hook is registered on the optimizer and is called after the optimizer step.
        """
        self._local_step += 1
        if self._local_step >= self._sync_every:
            self.sync()

    def sync(self) -> None:
        """
        Synchronizes and averages the model weights across the manager.
        """
        self._local_step += 1

        if (self._local_step - self._fragment_sync_offset) % self._sync_every != 0:
            return

        self._manager.start_quorum()
        self._perform_sync()
        self._local_step = 0

    def _perform_sync(self) -> None:
        """
        Overrides the sync method to calculate the pseudogradients, average them across the manager group, and
        step using the outer optimizer.
        """
        # Set the .grad field of each parameter to its pseudogradient
        for name, p in self._model.named_parameters():
        for name, p in self._model_fragment.named_parameters():
            local_param = extract_local_tensor(p.data)
            pseudogradient = local_param - self.original_parameters[name].to(p.device)
            if isinstance(p, DTensor):

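The `(self._local_step - self._fragment_sync_offset) % self._sync_every` check above is what staggers communication across fragments: each fragment still syncs every `sync_every` inner steps, just shifted by its own offset. A minimal sketch of the resulting schedule, using illustrative values (offsets chosen to match the `math.floor((sync_every / len(model_fragments)) * (i + 1))` formula used later in this PR):

# Illustrative only: which local steps trigger a sync for each fragment.
def syncs_at(step: int, fragment_sync_offset: int, sync_every: int) -> bool:
    # Mirrors the modulo check in _StreamingDiLoCoFragment.sync()
    return (step - fragment_sync_offset) % sync_every == 0

sync_every = 6
offsets = [2, 4, 6]  # floor((6 / 3) * (i + 1)) for i in 0..2
for step in range(1, 13):
    due = [i for i, off in enumerate(offsets) if syncs_at(step, off, sync_every)]
    if due:
        print(f"step {step}: fragment(s) {due} sync")
# step 2: fragment(s) [0] sync
# step 4: fragment(s) [1] sync
# step 6: fragment(s) [2] sync
# ...then steps 8, 10, 12 repeat the pattern, so only one fragment communicates per step.
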
@@ -319,11 +264,11 @@ def _perform_sync(self) -> None:

        self._average_grads()
        # Restore the parameters back to the previous state
        self._restore_parameters()
        self.restore_parameters()
        if self._manager.should_commit():
            # Use the outer optimizer to update the model parameters
            self._outer_optimizer.step()
            self._save_parameters()
            self.save_parameters()
        self._outer_optimizer.zero_grad()

    def _average_grads(self) -> None:

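Taken together, `_perform_sync` follows the DiLoCo outer update: form pseudogradients against the last synchronized weights, average them across replicas, rewind the fragment to those weights, and only then apply the outer optimizer. A simplified, single-device sketch of that sequence (no DTensor, quorum, or commit handling; `allreduce_avg` is a placeholder, not a torchft API):

import torch

def diloco_outer_step(module, original_parameters, outer_optimizer, allreduce_avg):
    # Sketch of one DiLoCo outer step for a single fragment.
    # 1. Pseudogradient: drift of the local weights from the last synchronized copy.
    for name, p in module.named_parameters():
        p.grad = p.data - original_parameters[name].to(p.device)
    # 2. Average pseudogradients across replicas (stand-in for the manager's allreduce).
    for p in module.parameters():
        allreduce_avg(p.grad)
    # 3. Rewind to the last synchronized weights before the outer update is applied.
    with torch.no_grad():
        for name, p in module.named_parameters():
            p.data.copy_(original_parameters[name])
    # 4. Outer optimizer step on the averaged pseudogradients, then refresh the backup copy.
    outer_optimizer.step()
    with torch.no_grad():
        for name, p in module.named_parameters():
            original_parameters[name].copy_(p.data)
    outer_optimizer.zero_grad()
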
@@ -340,13 +285,17 @@ def _average_grads(self) -> None:
    def _allreduce_per_param(self) -> None:
        """Performs allreduce on each gradient tensor separately (original method)."""
        works = []
        for p in self._model.parameters():
        for p in self._model_fragment.parameters():
            # Perform allreduce on the pseudogradients
            assert p.grad is not None
            if isinstance(p, DTensor):
                work = self._manager.allreduce(p.grad._local_tensor)
                work = self._manager.allreduce(
                    p.grad._local_tensor, should_quantize=self.should_quantize
                )
            else:
                work = self._manager.allreduce(p.grad)
                work = self._manager.allreduce(
                    p.grad, should_quantize=self.should_quantize
                )
            works.append(work)

        for work in works:

@@ -388,7 +337,9 @@ def bucketize_and_allreduce(
                pack_offset += numel
                flat_index += 1

            work = self._manager.allreduce(flat_buffer)
            work = self._manager.allreduce(
                flat_buffer, should_quantize=self.should_quantize
            )
            work.wait()

            for t, pack_offset, numel in bucket_tensors:

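The bucketized path above trades one collective per tensor for a single allreduce over a packed flat buffer. A self-contained sketch of the pack / reduce / unpack pattern, using plain torch.distributed as a stand-in for the manager's fault-tolerant allreduce and ignoring the bucket_cap_mb cap and quantization:

import torch
import torch.distributed as dist

def bucketized_allreduce_avg(grads: list[torch.Tensor]) -> None:
    # Sketch: average a list of gradients with one allreduce over a flat buffer.
    # Assumes an initialized process group and a single shared dtype/device.
    if not grads:
        return
    numels = [g.numel() for g in grads]
    flat = torch.empty(sum(numels), dtype=grads[0].dtype, device=grads[0].device)
    offset = 0
    for g, n in zip(grads, numels):  # pack every gradient into the flat buffer
        flat[offset : offset + n].copy_(g.reshape(-1))
        offset += n
    dist.all_reduce(flat, op=dist.ReduceOp.SUM)  # one collective instead of len(grads)
    flat.div_(dist.get_world_size())
    offset = 0
    for g, n in zip(grads, numels):  # unpack the averaged values back into each gradient
        g.copy_(flat[offset : offset + n].view_as(g))
        offset += n
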
@@ -400,8 +351,148 @@ def _allreduce_bucketized(self) -> None:
        """
        Averages gradients using bucketized allreduce with a fixed buffer.
        """
        grads = [p.grad for p in self._model.parameters() if p.grad is not None]
        grads = [
            p.grad for p in self._model_fragment.parameters() if p.grad is not None
        ]
        self.bucketize_and_allreduce(
            grads,
            bucket_size_bytes=self.bucket_cap_mb,
        )


class DiLoCo:
    """
    DiLoCo is a subclass of LocalSGD that overrides the synchronization
    mechanism to average and synchronize the pseudogradients (delta of the previous global weight and current local weights).

    The class implements a more general version of DiLoCo, Streaming DiLoCo,
    which synchronizes fragments of pseudogradients at different steps.

    This algorithm requires a backup copy of the
    weights. By default these are stored in CPU memory. If any error occurs
    during the DiLoCo step, the step will be discarded and the model
    parameters will reset back to the last time DiLoCo synchronized.

    DiLoCo paper: https://arxiv.org/pdf/2311.08105
    Streaming DiLoCo paper: https://arxiv.org/pdf/2501.18512
    """

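Here `model_fragments` is a list of sub-modules that jointly cover the model's parameters; each fragment gets its own backup copy and its own staggered sync point. The grouping below is a made-up illustration, not something the PR prescribes, and note that this version of the PR still requires exactly one fragment:

import torch.nn as nn

# Hypothetical model: twelve blocks plus a head.
blocks = [nn.Linear(256, 256) for _ in range(12)]
head = nn.Linear(256, 10)
model = nn.Sequential(*blocks, head)

# One possible fragmentation: three groups of consecutive blocks, head in the last group.
model_fragments = [
    nn.ModuleList(blocks[0:4]),
    nn.ModuleList(blocks[4:8]),
    nn.ModuleList(blocks[8:12] + [head]),
]
# Each fragment would be wrapped in its own _StreamingDiLoCoFragment, so only that
# fragment's pseudogradients are all-reduced at its scheduled sync step.
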
    def __init__(
        self,
        manager: Manager,
        model_fragments: List[nn.Module],
        inner_optimizer: optim.Optimizer,
        outer_optimizer: optim.Optimizer,
        sync_every: int,
        backup_device: Optional[torch.device] = None,
        pin_memory: bool = True,
        use_bucketization: bool = False,
        bucket_cap_mb: Optional[int] = None,
        should_quantize: bool = False,
        fragment_sync_delay: int = 0,
        fragment_update_alpha: float = 0.0,
    ) -> None:
        """
        Args:
            manager: The manager to use.
            model_fragments: The fragments of the model to wrap.
            inner_optimizer: The optimizer used for the local parameters every step.
            outer_optimizer: The optimizer used for the global parameters updated every "sync_every" steps.
            sync_every: How often to update the model weights.
            backup_device: The device to store the backup weights on. If None, the backup weights will be on CPU.
            pin_memory: Whether to pin the memory for the backup weights (only for CPU device).
            should_quantize: Whether to quantize the gradients before allreduce.
            fragment_sync_delay: Controls the number of inner steps to wait before blocking on a fragment's
                synchronization. This is the "tau" parameter in the Streaming DiLoCo paper.
            fragment_update_alpha: Determines how to mix the local and global optimized parameters.
        """

        if manager._use_async_quorum:
            raise ValueError(
                "Using DiLoCo requires synchronous quorum to be enabled. "
                "Ensure that the manager is initialized with use_async_quorum=False"
            )

        if sync_every < len(model_fragments):
            raise ValueError("Only 1 fragment can be synchronized at a time")

        # TODO: Support multiple fragments
        # This requires changing the manager to support `should_commit` for each
        # fragment separately.
        if len(model_fragments) != 1:
            raise ValueError("Multiple fragments are not supported yet")

        # TODO: Support `fragment_sync_delay`
        if fragment_sync_delay != 0:
            raise ValueError("Fragment synchronization delay is not supported yet")

        # TODO: Support `fragment_update_alpha`
        if fragment_update_alpha != 0.0:
            raise ValueError(
                "Merging local parameters with global parameters is not supported yet"
            )

        super().__init__()

        self._hooks: List[RemovableHandle] = []

        self._local_optimizer = inner_optimizer

        self._fragments: List[_StreamingDiLoCoFragment] = [
            _StreamingDiLoCoFragment(
                manager,
                model_fragment,
                math.floor((sync_every / len(model_fragments)) * (i + 1)),
                inner_optimizer,
                # TODO: Support different outer optimizers for each fragment
                outer_optimizer,
                sync_every,
                backup_device,
                pin_memory,
                use_bucketization,
                bucket_cap_mb,
                should_quantize,
            )
            for i, model_fragment in enumerate(model_fragments)
        ]

Review thread on the "TODO: Support different outer optimizers for each fragment" comment:

Comment: oh interesting, is that mentioned in the paper?

Reply: I think they should be different, otherwise things like momentum end up being the same for all fragments? Maybe it's not very important, though.

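The `math.floor((sync_every / len(model_fragments)) * (i + 1))` argument spreads the fragments' sync offsets evenly across the `sync_every` window; the last fragment's offset equals `sync_every`, so a single fragment reduces to plain DiLoCo. A quick check of the arithmetic with illustrative values:

import math

def fragment_offsets(sync_every: int, num_fragments: int) -> list[int]:
    # Same formula as the constructor above.
    return [
        math.floor((sync_every / num_fragments) * (i + 1))
        for i in range(num_fragments)
    ]

print(fragment_offsets(6, 3))    # [2, 4, 6]: fragments sync two steps apart
print(fragment_offsets(100, 1))  # [100]: single fragment == classic DiLoCo cadence
print(fragment_offsets(10, 4))   # [2, 5, 7, 10]: floor keeps offsets integral
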
        # Need to copy the parameters to the host to be safe if we are on the first step.
        self._save_parameters()

    def _save_parameters(self) -> None:
        for fragment in self._fragments:
            fragment.save_parameters()

    def _restore_parameters(self) -> None:
        for fragment in self._fragments:
            fragment.restore_parameters()

    def __enter__(self) -> "DiLoCo":
        # Add optimizer hook which increments the local step counter and syncs if necessary
        self._hooks.append(
            self._local_optimizer.register_step_post_hook(self._step_post_hook)
        )
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> bool:
        # Handle any cleanup or error handling here
        # Clean up hooks
        for hook in self._hooks:
            hook.remove()
        self._hooks.clear()

        return False  # Propagate exceptions

    def _step_post_hook(
        self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
    ) -> None:
        """
        This hook is registered on the optimizer and is called after the optimizer step.
        """
        for fragment in self._fragments:
            fragment.sync()
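
End to end, the new class is used the same way as before: as a context manager around the inner training loop, with the post-step hook on the inner optimizer driving each fragment's sync schedule. A hedged usage sketch; `manager`, `model`, `loader`, and `loss_fn` are assumed to exist already, and the hyperparameters are only illustrative:

import torch

inner_opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
outer_opt = torch.optim.SGD(model.parameters(), lr=0.7, momentum=0.9, nesterov=True)

with DiLoCo(
    manager,
    model_fragments=[model],  # a single fragment behaves like classic DiLoCo
    inner_optimizer=inner_opt,
    outer_optimizer=outer_opt,
    sync_every=100,
):
    for inputs, targets in loader:
        inner_opt.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        inner_opt.step()  # the post-step hook calls fragment.sync() at each sync point
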
Review thread on the new DiLoCo API:

Comment: nit: maybe from an API / UX perspective we can support nn.Module | List[nn.Module], with the specification that passing in a single nn.Module means the whole model.

Reply: Think it's better to have a smaller API surface and avoid having a special case?
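
For reference, the suggestion in the first comment would amount to normalizing the argument at the top of `__init__`; a sketch of what that could look like (not what the PR implements):

from typing import List, Union

import torch.nn as nn

def _normalize_fragments(model: Union[nn.Module, List[nn.Module]]) -> List[nn.Module]:
    # A bare nn.Module would mean "treat the whole model as a single fragment".
    if isinstance(model, nn.Module):
        return [model]
    return list(model)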