Generalize DiLoCo to support Streaming #205


Merged: 2 commits into pytorch:main from feature/streaming-diloco, Jun 3, 2025

Conversation

tushar00jain (Contributor)

Summary:

  • Add an option to perform quantized allreduce in the torchft manager.
  • Update the user-level DiLoCo APIs to also support Streaming DiLoCo; the API now takes a list of modules as input.
  • Create a class _StreamingDiLoCoFragment, used by DiLoCo to support streaming. Each fragment independently determines its schedule (when to send/sync).
  • Support for the "alpha" and "tau" parameters from the paper is left as a TODO; the plan is to add them in a separate PR.

Test Plan:

```
$ pytest -vs torchft/local_sgd_integ_test.py
$ pytest -vs torchft/local_sgd_test.py
```
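For intuition on the fragment scheduling described in the summary: each of the N fragments gets its own staggered sync point inside the sync window, following the `math.floor((sync_every / len(model_fragments)) * (i + 1))` expression that appears in the diff. A minimal standalone sketch (the helper name is hypothetical, not torchft API):

```python
import math

def fragment_sync_points(sync_every: int, num_fragments: int) -> list:
    # Fragment i first syncs at floor((sync_every / n) * (i + 1)),
    # spreading communication across the window instead of one burst.
    return [
        math.floor((sync_every / num_fragments) * (i + 1))
        for i in range(num_fragments)
    ]

# With sync_every=100 and 4 fragments, sync points are 25, 50, 75, 100.
print(fragment_sync_points(100, 4))
```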

@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on May 30, 2025
@tushar00jain force-pushed the feature/streaming-diloco branch 2 times, most recently from cb34ff1 to 35ca865 on May 30, 2025 23:38
@tushar00jain force-pushed the feature/streaming-diloco branch from 35ca865 to ab00d7d on May 31, 2025 18:12
@tushar00jain force-pushed the feature/streaming-diloco branch from ab00d7d to 0f07f2d on May 31, 2025 20:07
@H-Huang (Member) left a comment:

Nice!! Thanks for getting this out so quickly. Just a few small comments, but they can also be addressed in follow-up PRs.

```
@@ -267,7 +291,9 @@ def shutdown(self, wait: bool = True) -> None:
         self._manager.shutdown()
         self._executor.shutdown(wait=wait)

-    def allreduce(self, tensor: torch.Tensor) -> torch.futures.Future[torch.Tensor]:
+    def allreduce(
+        self, tensor: torch.Tensor, should_quantize: bool = False
+    ) -> torch.futures.Future[torch.Tensor]:
```
Member: I would also update some of the tests in manager_test.py to use the should_quantize flag. This can be done in a follow-up PR.
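For context on what such a flag typically gates (this is an illustrative sketch in plain Python, not torchft's Triton-based kernels): tensors are quantized before the collective and dequantized after, trading a small approximation error for less communication.

```python
def quantize(values, bits=8):
    # Symmetric quantization: scale so the largest magnitude maps to
    # the edge of the signed integer range.
    qmax = 2 ** (bits - 1) - 1
    scale = max(abs(v) for v in values) / qmax or 1.0
    return [round(v / scale) for v in values], scale

def dequantize(quants, scale):
    # Reconstruct approximate floats from the integer codes.
    return [q * scale for q in quants]

vals = [0.5, -1.0, 0.25]
q, s = quantize(vals)      # [64, -127, 32] at 8 bits
approx = dequantize(q, s)  # close to vals, within ~1/127
```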

```
except ImportError:
    from torch import cuda

    def allreduce_quantized(
```
Member: nit: is this stub necessary? Can't we just have a constant like TRITON_AVAILABLE and check it in the if statement in the implementation?

tushar00jain (Contributor, author) replied on Jun 3, 2025:

Fewer configuration options 🥲. If someone changes platforms, they can start using Triton automatically without having to modify the constant. It also avoids us having to configure CI for the different platforms.
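The pattern being described is roughly this (the module name `triton_kernels` is a hypothetical stand-in for the real Triton-backed import): try the fast path at import time and fall back to a stub, so the choice follows the platform automatically.

```python
try:
    # Hypothetical stand-in for the Triton-backed kernel import.
    from triton_kernels import allreduce_quantized
except ImportError:
    # Stub fallback: platforms without Triton fail loudly at call time,
    # and no TRITON_AVAILABLE constant has to be kept in sync.
    def allreduce_quantized(*args, **kwargs):
        raise NotImplementedError("quantized allreduce requires Triton")
```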

```
def __init__(
    self,
    manager: Manager,
    model_fragments: List[nn.Module],
```
Member: nit: maybe from an API/UX perspective we could support nn.Module | List[nn.Module], with the specification that passing a single nn.Module means the whole model.

tushar00jain (author): I think it's better to have a smaller API surface and avoid a special case.
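If the union type were adopted instead, the special case reduces to one normalization at the API boundary. A sketch with a hypothetical helper name (written without torch so it stands alone; any non-list argument is treated as a single whole-model fragment):

```python
def as_fragment_list(model_or_fragments):
    # A list is passed through; anything else is wrapped as a single
    # whole-model fragment.
    if isinstance(model_or_fragments, list):
        return list(model_or_fragments)
    return [model_or_fragments]
```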

```
model_fragment,
math.floor((sync_every / len(model_fragments)) * (i + 1)),
inner_optimizer,
# TODO: Support different outer optimizers for each fragment
```
Member: Oh interesting, is that mentioned in the paper?

tushar00jain (author): I think they should be different; otherwise state like momentum ends up being the same for all fragments. Maybe it's not very important, though.

@tushar00jain force-pushed the feature/streaming-diloco branch from 0f07f2d to 600864f on June 3, 2025 17:32
@tushar00jain merged commit 2ac219d into pytorch:main on Jun 3, 2025
8 checks passed
@tushar00jain deleted the feature/streaming-diloco branch on June 3, 2025 20:20