Generalize DiLoCo to support Streaming #205

Merged: 2 commits, Jun 3, 2025
8 changes: 6 additions & 2 deletions torchft/collectives.py
@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import TYPE_CHECKING

import torch

# pyre-ignore[21]: Could not find a module corresponding to import `triton`
@@ -17,7 +19,9 @@
)
from torch.futures import Future

from torchft.process_group import ProcessGroup
if TYPE_CHECKING:
from torchft.process_group import ProcessGroup

from torchft.quantization import (
fused_dequantize_from_fp8,
fused_quantize_into_fp8,
@@ -40,7 +44,7 @@ def _to_allgather_options(opts: AllreduceOptions) -> AllgatherOptions:
def allreduce_quantized(
tensors: list[torch.Tensor],
opts: AllreduceOptions | ReduceOp,
process_group: ProcessGroup,
process_group: "ProcessGroup",
sync_stream: cuda.Stream | None = None,
) -> Future[None]:
"""
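The collectives.py change above moves the ProcessGroup import behind a TYPE_CHECKING guard and switches the annotation to a string, the usual way to keep a type-only import from creating a runtime (often circular) dependency. A minimal sketch of the pattern, using hypothetical module and type names:

```python
# Minimal sketch of the TYPE_CHECKING pattern applied above.
# `heavy_module` and `HeavyType` are hypothetical names for illustration.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers, never at runtime,
    # so it cannot introduce an import cycle.
    from heavy_module import HeavyType


def use_it(obj: "HeavyType") -> None:
    # The quoted ("forward reference") annotation is not resolved at runtime,
    # so heavy_module does not need to be importable here.
    print(type(obj).__name__)
```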
253 changes: 172 additions & 81 deletions torchft/local_sgd.py
@@ -9,6 +9,7 @@
This module implements a fault tolerant version of LocalSGD and related methods.
"""
import logging
import math
from types import TracebackType
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type

@@ -158,70 +159,44 @@ def _average(self) -> list[torch.Tensor]:
return averaged_parameters


class DiLoCo:
"""
DiLoCo is a subclass of LocalSGD that overrides the synchronization
mechanism to average and synchronize the pseudogradients (delta of the previous global weight and current local weights).

This algorithm requires a backup copy of the
weights. By default these are stored in CPU memory. If any error occurs
during the DiLoCo step, the step will be discarded and the model
parameters will reset back to the last time DiLoCo synchronized.

DiLoCo paper: https://arxiv.org/pdf/2311.08105
"""

class _StreamingDiLoCoFragment:
bucket_cap_mb: int = 32 * 1024 * 1024
use_bucketization: bool = False

def __init__(
self,
manager: Manager,
model: nn.Module,
model_fragment: nn.Module,
fragment_sync_offset: int,
inner_optimizer: optim.Optimizer,
outer_optimizer: optim.Optimizer,
sync_every: int,
backup_device: Optional[torch.device] = None,
pin_memory: bool = True,
use_bucketization: bool = False,
bucket_cap_mb: Optional[int] = None,
should_quantize: bool = False,
) -> None:
"""
Args:
manager: The manager to use.
model: The model to wrap.
inner_optimizer: The optimizer used for the local parameters every step.
outer_optimizer: The optimizer used for the global parameters updated every "sync_every" steps.
sync_every: How often to update the model weights.
backup_device: The device to store the backup weights on. If None, the backup weights will be on CPU.
pin_memory: Whether to pin the memory for the backup weights (only for CPU device).
"""

if manager._use_async_quorum:
raise ValueError(
"Using DiLoCo require synchronous quorum to be enabled. "
"Ensure that the manager is initialized with use_async_quorum=False"
)
super().__init__()
self._manager = manager
self._model = model
self._model_fragment = model_fragment
self._fragment_sync_offset = fragment_sync_offset
self._local_optimizer = inner_optimizer
self._local_step = 0
self._sync_every = sync_every
assert sync_every >= 1, "sync_every must be greater than or equal to 1"
self._backup_device = backup_device
self._pin_memory = pin_memory

self._hooks: List[RemovableHandle] = []
self._outer_optimizer = outer_optimizer

if bucket_cap_mb is not None:
self.bucket_cap_mb = int(bucket_cap_mb * 1024 * 1024)

self.use_bucketization = use_bucketization
self.should_quantize = should_quantize

self.original_parameters: Dict[str, torch.Tensor] = {}
for name, p in self._model.named_parameters():
for name, p in self._model_fragment.named_parameters():
if isinstance(p, DTensor):
p = extract_local_tensor(p.data)

@@ -235,20 +210,17 @@ def __init__(
t = t.pin_memory()
self.original_parameters[name] = t

# Need to copy the parameters to the host to be safe if we are on the first step.
self._save_parameters()

def _save_parameters(self) -> None:
def save_parameters(self) -> None:
with torch.no_grad():
# TODO: consider running copy on a separate stream
for name, p in self._model.named_parameters():
for name, p in self._model_fragment.named_parameters():
param_to_local = extract_local_tensor(p.data)
self.original_parameters[name].copy_(param_to_local, non_blocking=True)

def _restore_parameters(self) -> None:
def restore_parameters(self) -> None:
with torch.no_grad():
# TODO: consider running copy on a separate stream
for name, p in self._model.named_parameters():
for name, p in self._model_fragment.named_parameters():
if isinstance(p, DTensor):
# we averaged the local version of the tensor so need to copy it back as a DTensor
p.data.copy_(
@@ -264,52 +236,25 @@ def _restore_parameters(self) -> None:
else:
p.data.copy_(self.original_parameters[name], non_blocking=False)

def __enter__(self) -> "DiLoCo":
# Add optimizer hook which increments the local step counter and syncs if necessary
self._hooks.append(
self._local_optimizer.register_step_post_hook(self._step_post_hook)
)
return self

def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> bool:
# Handle any cleanup or error handling here
# Clean up hooks
for hook in self._hooks:
hook.remove()
self._hooks.clear()

return False # Propagate exceptions

def _step_post_hook(
self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
) -> None:
"""
This hook is registered on the optimizer and is called after the optimizer step.
"""
self._local_step += 1
if self._local_step >= self._sync_every:
self.sync()

def sync(self) -> None:
"""
Synchronizes and averages the model weights across the manager.
"""
self._local_step += 1

if (self._local_step - self._fragment_sync_offset) % self._sync_every != 0:
return

self._manager.start_quorum()
self._perform_sync()
self._local_step = 0

def _perform_sync(self) -> None:
"""
Overrides the sync method to calculate the pseudogradients, average them across the manager group, and
step using the outer optimizer.
"""
# Set the .grad field of each parameter to its pseudogradient
for name, p in self._model.named_parameters():
for name, p in self._model_fragment.named_parameters():
local_param = extract_local_tensor(p.data)
pseudogradient = local_param - self.original_parameters[name].to(p.device)
if isinstance(p, DTensor):
@@ -319,11 +264,11 @@ def _perform_sync(self) -> None:

self._average_grads()
# Restore the parameters back to the previous state
self._restore_parameters()
self.restore_parameters()
if self._manager.should_commit():
# Use the outer optimizer to update the model parameters
self._outer_optimizer.step()
self._save_parameters()
self.save_parameters()
self._outer_optimizer.zero_grad()

def _average_grads(self) -> None:
@@ -340,13 +285,17 @@ def _average_grads(self) -> None:
def _allreduce_per_param(self) -> None:
"""Performs allreduce on each gradient tensor separately (original method)."""
works = []
for p in self._model.parameters():
for p in self._model_fragment.parameters():
# Perform allreduce on the pseudogradients
assert p.grad is not None
if isinstance(p, DTensor):
work = self._manager.allreduce(p.grad._local_tensor)
work = self._manager.allreduce(
p.grad._local_tensor, should_quantize=self.should_quantize
)
else:
work = self._manager.allreduce(p.grad)
work = self._manager.allreduce(
p.grad, should_quantize=self.should_quantize
)
works.append(work)

for work in works:
@@ -388,7 +337,9 @@ def bucketize_and_allreduce(
pack_offset += numel
flat_index += 1

work = self._manager.allreduce(flat_buffer)
work = self._manager.allreduce(
flat_buffer, should_quantize=self.should_quantize
)
work.wait()

for t, pack_offset, numel in bucket_tensors:
@@ -400,8 +351,148 @@ def _allreduce_bucketized(self) -> None:
"""
Averages gradients using bucketized allreduce with a fixed buffer.
"""
grads = [p.grad for p in self._model.parameters() if p.grad is not None]
grads = [
p.grad for p in self._model_fragment.parameters() if p.grad is not None
]
self.bucketize_and_allreduce(
grads,
bucket_size_bytes=self.bucket_cap_mb,
)
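
bucketize_and_allreduce above packs many small pseudogradient tensors into a flat buffer, issues one allreduce per buffer, and copies the averaged slices back out. The following is a rough, self-contained sketch of that idea against plain torch.distributed rather than the Manager API; it assumes all gradients share a dtype, that the process group is already initialized, and that the backend supports ReduceOp.AVG (e.g. NCCL). The real implementation above reuses a fixed buffer and routes through self._manager.allreduce with optional quantization.

```python
import torch
import torch.distributed as dist


def bucketed_allreduce_avg(
    grads: list[torch.Tensor],
    bucket_cap_bytes: int = 32 * 1024 * 1024,
) -> None:
    """Greedily pack same-dtype gradients into flat buckets and allreduce each bucket once."""
    # Group tensors so each bucket stays under the byte cap.
    buckets: list[list[torch.Tensor]] = [[]]
    used = 0
    for g in grads:
        size = g.numel() * g.element_size()
        if buckets[-1] and used + size > bucket_cap_bytes:
            buckets.append([])  # start a new bucket once the cap would be exceeded
            used = 0
        buckets[-1].append(g)
        used += size

    for bucket in buckets:
        if not bucket:
            continue
        flat = torch.cat([t.reshape(-1) for t in bucket])  # pack into one flat buffer
        dist.all_reduce(flat, op=dist.ReduceOp.AVG)        # one collective per bucket
        offset = 0
        for t in bucket:                                   # unpack the averaged values
            t.copy_(flat[offset : offset + t.numel()].view_as(t))
            offset += t.numel()
```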


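To make the staggering concrete, here is a small worked example of the offsets that the DiLoCo wrapper defined below hands to each fragment, and of the steps on which sync() first fires. The numbers are purely illustrative, and since this PR still rejects more than one fragment, the multi-fragment schedule describes intended behavior rather than something exercised here.

```python
import math

# Illustrative configuration only; this PR still restricts DiLoCo to one fragment.
sync_every = 6
num_fragments = 3

# Offsets as assigned by DiLoCo.__init__ below.
offsets = [
    math.floor((sync_every / num_fragments) * (i + 1)) for i in range(num_fragments)
]
print(offsets)  # [2, 4, 6]

# A fragment with offset o becomes eligible to sync when
# (local_step - o) % sync_every == 0, so the three fragments first
# synchronize at local steps 2, 4 and 6 respectively rather than all
# on the same step; this is also why the constructor requires
# sync_every >= len(model_fragments).
```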
class DiLoCo:
"""
DiLoCo is a subclass of LocalSGD that overrides the synchronization
mechanism to average and synchronize the pseudogradients (the delta between the previous global weights and the current local weights).

The class implements a more general version of DiLoCo, Streaming DiLoCo,
which synchronizes fragments of pseudogradients at different steps.

This algorithm requires a backup copy of the
weights. By default these are stored in CPU memory. If any error occurs
during the DiLoCo step, the step will be discarded and the model
parameters will reset back to the last time DiLoCo synchronized.

DiLoCo paper: https://arxiv.org/pdf/2311.08105
Streaming DiLoCo paper: https://arxiv.org/pdf/2501.18512
"""

def __init__(
self,
manager: Manager,
model_fragments: List[nn.Module],
Member: nit: maybe from an API / UX perspective we can support nn.Module | List[nn.Module] with the specification that passing in a single nn.Module means whole model.

Contributor Author: Think it's better to have a smaller API surface and avoid having a special case?

inner_optimizer: optim.Optimizer,
outer_optimizer: optim.Optimizer,
sync_every: int,
backup_device: Optional[torch.device] = None,
pin_memory: bool = True,
use_bucketization: bool = False,
bucket_cap_mb: Optional[int] = None,
should_quantize: bool = False,
fragment_sync_delay: int = 0,
fragment_update_alpha: float = 0.0,
) -> None:
"""
Args:
manager: The manager to use.
model_fragments: The fragments of the model to wrap.
inner_optimizer: The optimizer used for the local parameters every step.
outer_optimizer: The optimizer used for the global parameters updated every "sync_every" steps.
sync_every: How often to update the model weights.
backup_device: The device to store the backup weights on. If None, the backup weights will be on CPU.
pin_memory: Whether to pin the memory for the backup weights (only for CPU device).
should_quantize: Whether to quantize the gradients before allreduce.
fragment_sync_delay: Controls the number of inner steps to wait before blocking on a fragment's
synchronization. This is the "tau" parameter in the Streaming DiLoCo paper.
fragment_update_alpha: Determines how to mix the locally optimized and globally optimized parameters.
"""

if manager._use_async_quorum:
raise ValueError(
"Using DiLoCo require synchronous quorum to be enabled. "
"Ensure that the manager is initialized with use_async_quorum=False"
)

if sync_every < len(model_fragments):
raise ValueError("Only 1 fragment can be syncrhonized at a time")

# TODO: Support multiple fragments
# This requires changing the manager to support `should_commit` for each
# fragment separately.
if len(model_fragments) != 1:
raise ValueError("Multiple fragments are not supported yet")

# TODO: Support `fragment_sync_delay`
if fragment_sync_delay != 0:
raise ValueError("Fragment synchronization delay is not supported yet")

# TODO: Support `fragment_update_alpha`
if fragment_update_alpha != 0.0:
raise ValueError(
"Merging local parameters with global parameters is not supported yet"
)

super().__init__()

self._hooks: List[RemovableHandle] = []

self._local_optimizer = inner_optimizer

self._fragments: List[_StreamingDiLoCoFragment] = [
_StreamingDiLoCoFragment(
manager,
model_fragment,
math.floor((sync_every / len(model_fragments)) * (i + 1)),
inner_optimizer,
# TODO: Support different outer optimizers for each fragment
Member: oh interesting, is that mentioned in the paper?

Contributor Author: I think they should be different, otherwise things like momentum end up being the same for all fragments? Maybe it's not very important though.

outer_optimizer,
sync_every,
backup_device,
pin_memory,
use_bucketization,
bucket_cap_mb,
should_quantize,
)
for i, model_fragment in enumerate(model_fragments)
]

# Need to copy the parameters to the host to be safe if we are on the first step.
self._save_parameters()

def _save_parameters(self) -> None:
for fragment in self._fragments:
fragment.save_parameters()

def _restore_parameters(self) -> None:
for fragment in self._fragments:
fragment.restore_parameters()

def __enter__(self) -> "DiLoCo":
# Add optimizer hook which increments the local step counter and syncs if necessary
self._hooks.append(
self._local_optimizer.register_step_post_hook(self._step_post_hook)
)
return self

def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> bool:
# Handle any cleanup or error handling here
# Clean up hooks
for hook in self._hooks:
hook.remove()
self._hooks.clear()

return False # Propagate exceptions

def _step_post_hook(
self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
) -> None:
"""
This hook is registered on the optimizer and is called after the optimizer step.
"""
for fragment in self._fragments:
fragment.sync()
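
Finally, a hedged end-to-end sketch of how this wrapper is driven. It mirrors the context-manager pattern the class keeps from the previous DiLoCo implementation (the post-step hook on the inner optimizer triggers sync() on each fragment). The manager, model, optimizers, hyperparameters, and data below are placeholders rather than a recommended configuration, and only one fragment (the whole model) is passed, matching the single-fragment restriction in this PR.

```python
# Usage sketch only: `manager` stands in for a torchft Manager constructed
# elsewhere with use_async_quorum=False; the model, optimizers and data are
# toy placeholders.
import torch
import torch.nn.functional as F
from torch import nn, optim

from torchft.local_sgd import DiLoCo

model = nn.Linear(128, 10)
inner_optimizer = optim.AdamW(model.parameters(), lr=4e-4)
outer_optimizer = optim.SGD(model.parameters(), lr=0.7, momentum=0.9, nesterov=True)

with DiLoCo(
    manager,                   # placeholder: torchft Manager with synchronous quorum
    model_fragments=[model],   # one fragment == the whole model for now
    inner_optimizer=inner_optimizer,
    outer_optimizer=outer_optimizer,
    sync_every=100,            # outer (fragment) sync every 100 inner steps
):
    for _ in range(1000):
        batch = torch.randn(32, 128)
        target = torch.randint(0, 10, (32,))
        loss = F.cross_entropy(model(batch), target)
        loss.backward()
        inner_optimizer.step()   # the registered post-step hook calls sync() here
        inner_optimizer.zero_grad()
```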