
Add DiLoCo #76

Merged · merged 1 commit into gh/H-Huang/1/base on Jan 30, 2025
Conversation

@H-Huang (Member) commented Jan 21, 2025

Stack from ghstack (oldest at bottom):

API Usage

# Note: imports and module paths below are illustrative; SimpleModel and dataloader
# are assumed to be defined elsewhere.
import torch
from torch import optim
from unittest.mock import create_autospec

from torchft.local_sgd import DiLoCo, LocalSGD
from torchft.manager import Manager

# LocalSGD example
model = SimpleModel()
optimizer = optim.SGD(model.parameters())
manager = create_autospec(Manager)
with LocalSGD(manager, model, optimizer, sync_every=2):
    for inp, label in dataloader:
        loss = model(inp).mean()
        loss.backward()
        optimizer.step()

# DiLoCo example
model = SimpleModel()
inner_optimizer = torch.optim.AdamW(
    model.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
)
outer_optimizer = torch.optim.SGD(
    model.parameters(), lr=0.7, momentum=0.9, nesterov=True
)
manager = create_autospec(Manager)
with DiLoCo(manager, model, inner_optimizer, outer_optimizer, sync_every=2):
    for inp, label in dataloader:
        loss = model(inp).mean()
        loss.backward()
        inner_optimizer.step()
        # outer_optimizer is actually used every `sync_every` steps, but this is hidden from the user

Changes

  • Updated LocalSGD to be a context manager rather than an nn.Module wrapper.
  • Added DiLoCo. This is a subclass of LocalSGD since a lot of code is shared.
  • Added a test to validate that the global models and outer optimizers are the same at every step.

discussion doc: https://docs.google.com/document/d/11c5JwQpSzilrDvK-vNsgQhpXAihbMn-hTRC8y3LiGqY/edit?tab=t.0#heading=h.izo4yi6jz4mk

H-Huang added a commit that referenced this pull request Jan 21, 2025
ghstack-source-id: 68e071e88b5b238d137e0ecdaa33d97b79370b22
Pull Request resolved: #76
@facebook-github-bot added the "CLA Signed" label (managed by the Meta Open Source bot) Jan 21, 2025
@H-Huang marked this pull request as draft January 21, 2025 22:18
H-Huang added a commit that referenced this pull request Jan 21, 2025
ghstack-source-id: 2153244514c7ff795ec590804d208c9d9cd2b4ed
Pull Request resolved: #76
@H-Huang marked this pull request as ready for review January 22, 2025 19:56
@H-Huang requested review from d4l3k, wconstab and c-p-i-o January 22, 2025 19:57
@d4l3k (Member) left a comment

Code is looking good :) There are a couple of edge cases I want to discuss more.

torchft/local_sgd.py (resolved review threads, outdated diff)
"""
ensure model has the same weights
"""
pass
Member

what's this for?

Member Author

This isn't used, but it was a reminder for me to figure out how to make sure the models across all the replica groups are the same. For LocalSGD this doesn't matter as much, since each group will have the same model after step 1.

For DiLoCo we need to ensure each replica starts with the same model weights; otherwise the replicas will diverge, since we are only exchanging pseudogradients.

In DDP we broadcast from rank 0 to the entire group; maybe we should do the same? If we don't do this explicitly for the user, then maybe we should at least validate that the models have the same weights? What are your thoughts there?
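For reference, the DDP-style broadcast mentioned here looks roughly like the sketch below. It uses plain torch.distributed for illustration; in torchft the cross-replica-group communication would actually go through the Manager.

import torch
import torch.distributed as dist

# Broadcast rank 0's parameters so every replica starts from identical weights.
# Assumes the process group is already initialized.
with torch.no_grad():
    for param in model.parameters():
        dist.broadcast(param, src=0)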


One middle ground could be to add a public API function to the DiLoCo object so users can opt into the extra validation step, but not get it by default. I'm not sure it is defensible to insist on doing a broadcast of parameters, given that models have gotten a lot bigger since the time DDP was developed.

e.g.

with DiLoCo(m ...) as d:
    d.validate_model_init()
    for train steps...

OTOH, as I write this, it occurs to me that this is really not a DiLoCo-specific feature. It might be even better to just provide some kind of distributed util function that users can pass a local model obj to, which performs some cross-rank validation of parameter initialization.

Member Author

I think the distributed utility for validation would be helpful. I got around this requirement in my integration test by initializing in the main process and then passing that state_dict() to the threads running the replica groups. In a real environment this wouldn't be possible, so it would make sense to load from a checkpoint. We don't have a method that loads from a checkpoint in the manager, which is also blocking, so I can implement that in a follow-up PR.
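A minimal sketch of the validation utility discussed in this thread; validate_model_init is a hypothetical helper (not part of this PR), and it assumes torch.distributed is initialized with a backend that supports CPU tensors (e.g. gloo).

import torch
import torch.distributed as dist

def validate_model_init(model: torch.nn.Module) -> None:
    """Raise if this rank's parameters differ from rank 0's."""
    # Flatten all parameters into a single CPU fingerprint tensor.
    local = torch.cat([p.detach().flatten().float().cpu() for p in model.parameters()])
    reference = local.clone()
    # Every rank receives rank 0's fingerprint and compares locally.
    dist.broadcast(reference, src=0)
    if not torch.equal(local, reference):
        raise RuntimeError("parameters differ from rank 0; replicas must start from identical weights")

Comparing full flattened parameters keeps the sketch simple; a real utility would more likely compare per-parameter checksums or hashes to save memory and bandwidth.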

torchft/local_sgd.py (resolved review threads)

def state_dict() -> Dict[str, Dict[str, object]]:
    return {
        "model": m.state_dict(),
Member

One concern I have with this implementation is that we're transferring the model state_dict, not the backed-up parameters. If the local training is too fast, we might end up with a corrupted copy of the weights in async mode.

There are two cases I'm concerned about:

1. async mode

  1. trainers 0,1 call start_quorum and get a new quorum
  2. trainer 1 is behind so fetches the checkpoint from trainer 0
  3. trainer 0 completes a step, so its weights have changed
  4. trainer 1 finishes fetching the checkpoint and the weights are now inconsistent, since they were actively being changed while fetching

We can solve this one by adding an assertion that the manager is in sync mode -- which is probably the behavior users want regardless.

2. should_commit==False

  1. trainers 0,1 call start_quorum and get a new quorum
  2. trainer 1 is behind so fetches the checkpoint from trainer 0
  3. should_commit() == False
  4. trainer 0,1 fallback to the backup parameters
  5. trainer 1 never saved the new checkpoint parameters and thus rolls back to untrained weights
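The sync-mode assertion suggested for case 1 could be as simple as the sketch below; use_async_quorum is an assumed manager attribute, used here only for illustration.

# Reject async mode up front so a checkpoint can never be fetched
# while the source trainer is still mutating its weights.
if getattr(manager, "use_async_quorum", False):
    raise AssertionError("LocalSGD/DiLoCo require the manager to run in sync quorum mode")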

Member Author

Good points. Yeah, this implementation hasn't really considered checkpointing very well. Your solution for number 1 makes sense. I have two follow-up questions:

  1. How should CheckpointServer interact with LocalSGD and DiLoCo? Should we be saving the backup parameters to the checkpoint server after each sync?
  2. For number 2, when do we have should_commit == False? Is this just for failures?

Member

@H-Huang

  1. With the current design for LocalSGD/DiLoCo we always transfer the backup parameters (i.e. the weights right after the outer LocalSGD/DiLoCo averaging step). Effectively we treat the primary weights as the "intermediate state", equivalent to the gradients in normal training, and we only commit them at the end of the N inner steps / 1 outer step. With sync mode this is less of a concern, since we will always heal before training, so the backup and primary weights will always be the same.
  2. If any error occurs during the cross-replica communication we'll end up with should_commit == False; this will most commonly happen if a worker in a different replica group crashes and then NCCL/Gloo times out or errors. should_commit checks whether any error has happened on any of the ranks in a replica group. This is intended to prevent partial (say 1/10 ranks had an error) and corrupted (a rank had an error in communication and a tensor is in a partially updated state) model updates.

In terms of automatically backing things up, I think we may be able to use register_load_state_dict_post_hook to detect when a state_dict is loaded into the model and trigger a backup when that happens.
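A rough sketch of that hook-based backup idea; self._model and self._save_parameters are assumed names based on this discussion, while register_load_state_dict_post_hook is the standard torch.nn.Module hook.

def _register_backup_on_load(self) -> None:
    def _backup_after_load(module: torch.nn.Module, incompatible_keys) -> None:
        # Whenever new weights are loaded (e.g. after healing from another replica),
        # refresh the backup copy so a later rollback cannot restore stale weights.
        self._save_parameters()  # assumed backup helper

    self._model.register_load_state_dict_post_hook(_backup_after_load)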

Member

As discussed offline, after this PR lands we still have issue 2. We should follow this up with a PR to support shrink_only=True quorum, so we never discard any computation, and drop the rollback behavior to mimic how people actually use LocalSGD in practice.

cc @Jackmin801

H-Huang added a commit that referenced this pull request Jan 28, 2025
ghstack-source-id: 030098f8023ba12ae756b8b26bcb740ec9151151
Pull Request resolved: #76
H-Huang added a commit that referenced this pull request Jan 28, 2025
ghstack-source-id: 588c13d633fc69574f8a87e46bbe8ae4069d4a3c
Pull Request resolved: #76
@wconstab

Updated LocalSGD to be a context manager rather than an nn.Module wrapper. This required adding a pre_forward_hook to the model to start the quorum

Haven't read the code yet, but having the model as an argument to the DiLoCo object made me think about how this would work with pipeline parallelism, where there are multiple model parts.

What is the reason that we need start-quorum to be connected to the model via pre-forward? Doesn't it just need to happen once per inner step? Again, with PP, there is more than one forward per inner-optimizer step. So we need to at least ensure we aren't starting a quorum too often.

This should be called before the optimizer step.

This will start the quorum and save the parameters if this is the first step.
Start the quorum before each module forward.


Do we need a new quorum on every module forward?

We don't sync after every module forward/backward, do we? IIUC we sync every `sync_every` steps. So do we only need a new quorum if we just finished a sync step?

I wonder whether this would be clearer if we inverted things

diloco = DiLoCo(...)

for steps:
   # every time we enter/exit this context, it signals a quorum start.  not sure if that is helpful, or maybe i'm misunderstanding things.
   with diloco():
        model()...
        optimizer.step()


I guess, riffing on that, the helpers for DiLoCo and LocalSGD seem fine to me (and convenient), but I think even more important in a way is providing a good base class that people can build other recipes on top of. The key pieces seem to be:

  • expressing how often to sync
  • expressing the code to run inside the quorum
  • expressing the code to run across quorums
quorum = QuorumBuilder(outer_optimizer, ...)

for outer_steps:
   for inner_steps:
      with quorum:
          model()...
          inner_optimizer.step()
   if quorum.should_commit:
      average()
      outer_optimizer.step()

I am probably misusing some terms and APIs here, and I'm not sure having separate loops is desirable, but I am wondering if we have building blocks that let users write both DiLoCo and LocalSGD in about that many lines of code themselves.

If we do, then building the DiLoCo/LocalSGD helpers on top seems nice too

@H-Huang (Member Author), Jan 30, 2025

These are good points! I don't actually think we need the quorum for each forward step, so I removed the pre-forward hook. The only downside to not computing the quorum each step is delayed failure detection. This would be a problem if dropping some replica groups hurt convergence, but according to the DiLoCo paper they see only a small drop in perplexity compared to a perfect-communication experiment.

I like the more general building block pieces you mentioned. Currently Manager acts as the component that manages the quorums and what is happening across replica groups, so in that same vein I think it makes sense to enhance it to also run custom code. In doing so, I think we can achieve something similar to what you wrote above and give users better flexibility in writing their own recipes.

@fegin (Contributor) commented Jan 29, 2025

@H-Huang just fyi, ghstack doesn't work with TorchFT. You won't be able to land it later.

torchft/local_sgd.py (resolved review thread, outdated diff)
H-Huang added a commit that referenced this pull request Jan 29, 2025
ghstack-source-id: 357333e8601958ad86a2cbff78e56ea2cbe447c2
Pull Request resolved: #76
@H-Huang changed the title from "[WIP] Add DiLoCo" to "Add DiLoCo" Jan 29, 2025
H-Huang added a commit that referenced this pull request Jan 29, 2025
ghstack-source-id: 357333e8601958ad86a2cbff78e56ea2cbe447c2
Pull Request resolved: #76
H-Huang added a commit that referenced this pull request Jan 30, 2025
ghstack-source-id: 357333e8601958ad86a2cbff78e56ea2cbe447c2
Pull Request resolved: #76
@wconstab
Copy link

Why doesn't ghstack work? I thought it didn't need special enablement.
Did you try using cmdline 'ghstack land'? @fegin

@H-Huang requested a review from d4l3k January 30, 2025 16:44

@d4l3k (Member) left a comment

LGTM

torchft/local_sgd.py (resolved review thread, outdated diff)
# Set the .grad field of each parameter to its pseudogradient
for name, p in self._model.named_parameters():
    assert name in self._backup_parameters
    pseudogradient = p.data - self._backup_parameters[name]
Member

Do we need to check which device the backup_parameters are on? We need a .to call here if they're on CPU, right?

Member Author

Oh yeah, good point. I will change up the backup_parameters and the load_state_dict in a follow-up PR.
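For illustration, the device handling discussed above might look like this sketch; the attribute names follow the snippet quoted earlier, and the .to fallback is the assumption under discussion rather than code from this PR.

# Set the .grad field of each parameter to its pseudogradient,
# moving the backup to the parameter's device if it lives on CPU.
for name, p in self._model.named_parameters():
    backup = self._backup_parameters[name]
    if backup.device != p.device:
        backup = backup.to(p.device)
    p.grad = p.data - backup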

torch.testing.assert_close(
    state_dict["outer_optim"],
    state_dicts[0][str_step]["outer_optim"],
)
Member

Should we be checking the "model" params as well? Or are they equivalent to backup_params at the end of the step?

Member Author

Correct, they will be equivalent at the end of the step

ghstack-source-id: 357333e8601958ad86a2cbff78e56ea2cbe447c2
Pull Request resolved: #76
@H-Huang merged commit 3c12ce9 into gh/H-Huang/1/base Jan 30, 2025
6 checks passed
@H-Huang mentioned this pull request Jan 30, 2025
H-Huang added a commit that referenced this pull request Jan 31, 2025
ghstack-source-id: 357333e8601958ad86a2cbff78e56ea2cbe447c2
Pull Request resolved: #76