Add DiLoCo #76
Conversation
Code is looking good :) There are a couple of edge cases I want to discuss more.
torchft/local_sgd.py
""" | ||
ensure model has the same weights | ||
""" | ||
pass |
what's this for?
This isn't used, but it was a reminder for me to figure out how to make sure the models across all the replica groups are the same. For LocalSGD this doesn't matter as much since each group will have the same model after step 1.
For DiLoCo we need to ensure each replica starts with the same model weights; otherwise the replicas will diverge since we are only exchanging pseudogradients.
In DDP, we broadcast from rank 0 to the entire group; maybe we should do the same? If we don't do this explicitly for the user, then maybe we should at least validate that the models have the same weights? What are your thoughts there?
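For reference, a minimal sketch of the DDP-style option (broadcasting from rank 0), assuming a process group spanning the replica groups is already initialized; the helper name is hypothetical, not a torchft API:

import torch
import torch.distributed as dist

# Hypothetical helper: broadcast parameters and buffers from rank 0 so every
# replica starts from identical weights, mirroring what DDP does at init.
def broadcast_initial_weights(model: torch.nn.Module, src: int = 0) -> None:
    with torch.no_grad():
        for p in model.parameters():
            dist.broadcast(p.data, src=src)
        for b in model.buffers():
            dist.broadcast(b.data, src=src)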
one middle ground could be to add a public API function to the DiLoCo object so users can opt into the extra validation step, but not get it by default. I'm not sure it is defensible to insist on doing a broadcast of parameters, given models have gotten a lot bigger since the time DDP was developed.
e.g.

with DiLoCo(m, ...) as d:
    d.validate_model_init()
    for train steps...
OTOH, as I write this, it occurs to me that this is really not a DiLoCo-specific feature. It might be even better to just provide some kind of distributed util function that users can pass a local model object to, which performs some cross-rank validation of parameter initialization.
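As a rough illustration of that distributed util idea, a sketch assuming an already-initialized process group; the function name and the sum-based summary are illustrative only (a cheap scalar summary can miss pathological mismatches), not an existing torchft API:

import torch
import torch.distributed as dist

# Compare a scalar summary of each parameter across ranks; min == max on every
# rank only if all ranks hold (numerically) identical values.
def validate_param_init(model: torch.nn.Module, group=None) -> None:
    with torch.no_grad():
        for name, p in model.named_parameters():
            summary = p.detach().double().sum()
            lo, hi = summary.clone(), summary.clone()
            dist.all_reduce(lo, op=dist.ReduceOp.MIN, group=group)
            dist.all_reduce(hi, op=dist.ReduceOp.MAX, group=group)
            if not torch.allclose(lo, hi):
                raise RuntimeError(f"parameter {name!r} is not consistent across ranks")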
I think the distributed utility for validation would be helpful. I got around this requirement in my integration test by initializing in the main process and then passing that state_dict() to the threads running the replica groups. In a real environment this wouldn't be possible, so it would make sense to load from a checkpoint. We don't have a method in the manager that loads from a checkpoint, which is also blocking, so I can implement that in a follow-up PR.
def state_dict() -> Dict[str, Dict[str, object]]:
    return {
        "model": m.state_dict(),
One concern I have with this implementation is that we're transferring the model state_dict, not the backed-up parameters. If the local training is too fast we might end up with a corrupted copy of the weights in async mode.
There are two cases I'm concerned about:
1. async mode
- trainers 0,1 call start_quorum and get a new quorum
- trainer 1 is behind so fetches the checkpoint from trainer 0
- trainer 0 completes a step so its weights have changed
- trainer 1 finishes fetching the checkpoint and the weights are now inconsistent since they were actively being changed while fetching
We can solve this one by adding an assertion that the manager is in sync mode -- which is probably the behavior users want regardless.
2. should_commit == False
- trainers 0,1 call start_quorum and get a new quorum
- trainer 1 is behind so fetches the checkpoint from trainer 0
- should_commit() == False
- trainers 0,1 fall back to the backup parameters
- trainer 1 never saved the new checkpoint parameters and thus rolls back to untrained weights
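A sketch of what the assertion for case 1 might look like; the use_async_quorum attribute is an assumption about how the Manager exposes its quorum mode, not a confirmed torchft attribute:

# Assumed attribute name; adjust to however the Manager exposes its quorum mode.
if getattr(manager, "use_async_quorum", False):
    raise AssertionError(
        "LocalSGD/DiLoCo requires the Manager to run in sync mode so the "
        "checkpoint transfer finishes before the source replica keeps training"
    )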
Good points. Yeah, this implementation hasn't really considered checkpointing very well. Your solution for number 1 makes sense. I have two follow-up questions:
- How should CheckpointServer interact with LocalSGD and DiLoCo? Should we be saving the backup parameters to the checkpoint server after each sync?
- For number 2, when do we have should_commit == False? Is this just for failures?
- With the current design for LocalSGD/DiLoCo we always transfer the backup parameters (i.e. the weights right after the outer LocalSGD/DiLoCo averaging step). Effectively we treat the primary weights as the "intermediate state", equivalent to the gradients in normal training, and we only commit them at the end of the N inner steps / 1 outer step (see the toy sketch after these bullets). With sync mode this is less of a concern since we will always heal before training, so the backup and primary weights will always be the same.
- If any error occurs during the cross-replica communication we'll end up with should_commit == False; this most commonly happens when a worker in a different replica group crashes and NCCL/Gloo then times out or errors. should_commit checks whether any error has happened on any of the ranks in a replica group. This is intended to prevent partial (say 1/10 ranks had an error) and corrupted (a rank had an error in communication and a tensor is in a partially updated state) model updates.
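A small self-contained toy of the commit/rollback flow described in these bullets; everything here (the toy model, the sync_every value, the committed flag standing in for should_commit()) is illustrative, not torchft code:

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
inner_opt = torch.optim.SGD(model.parameters(), lr=0.01)

def take_backup(m: nn.Module) -> dict:
    # Snapshot of the last committed (post outer-step) weights.
    return {n: p.detach().clone() for n, p in m.named_parameters()}

backup = take_backup(model)

sync_every = 8
for _ in range(sync_every):           # N inner steps on the primary weights
    inner_opt.zero_grad()
    model(torch.randn(2, 4)).sum().backward()
    inner_opt.step()

# ... cross-replica averaging of pseudogradients / outer step would happen here ...

committed = True                      # stand-in for manager.should_commit()
if committed:
    backup = take_backup(model)       # primary weights become the committed state
else:
    with torch.no_grad():             # roll back: discard the uncommitted inner steps
        for n, p in model.named_parameters():
            p.copy_(backup[n])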
In terms of automatically backing things up, I think we may be able to use register_load_state_dict_post_hook to detect when a state_dict is loaded into the model and trigger a backup when that happens.
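For example, a sketch of that hook idea; attach_backup_hook and the backup_parameters dict are assumed names, while register_load_state_dict_post_hook is the real nn.Module API:

import torch
import torch.nn as nn

def attach_backup_hook(module: nn.Module, backup_parameters: dict) -> None:
    # Post-hook signature is hook(module, incompatible_keys).
    def _refresh_backup(mod: nn.Module, incompatible_keys) -> None:
        # Runs after any load_state_dict (e.g. after healing from a peer's
        # checkpoint) so the backup reflects the freshly loaded weights and a
        # later rollback cannot revert to stale or untrained ones.
        with torch.no_grad():
            for name, p in mod.named_parameters():
                backup_parameters[name] = p.detach().clone()

    module.register_load_state_dict_post_hook(_refresh_backup)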
As discussed offline, after this PR lands we still have issue 2. We should follow this up with a PR to support shrink_only=True quorum so we never discard any computation and drop the rollback behavior to mimic how people actually use LocalSGD in practice
cc @Jackmin801
Haven't read the code yet, but what is the reason that we need start_quorum to be connected to the model via a pre-forward hook? Doesn't it just need to happen once per inner step? Again, with PP there is more than one forward per inner-optimizer step, so we need to at least ensure we aren't starting a quorum too often.
torchft/local_sgd.py
This should be called before the optimizer step.

This will start the quorum and save the parameters if this is the first step.
Start the quorum before each module forward.
Do we need a new quorum on every module forward? We don't sync after every module forward/backward, do we? IIUC we sync every 'sync_every' steps. So do we only need a new quorum if we just finished a sync step?
I wonder whether this would be clearer if we inverted things:
diloco = DiLoCo(...)
for steps:
    # every time we enter/exit this context, it signals a quorum start. not sure if that is helpful, or maybe i'm misunderstanding things.
    with diloco():
        model()...
        optimizer.step()
I guess riffing on that, the helpers for DiLoCo and LocalSGD seem fine to me (and convenient), but I think even more important in a way is providing a good base class people can build other recipes on top of. The key pieces seem to be:
- expressing how often to sync
- expressing the code to run inside the quorum
- expressing the code to run across quorums
quorum = QuorumBuilder(outer_optimizer, ...)
for outer_steps:
    for inner_steps:
        with quorum:
            model()...
            inner_optimizer.step()
    if quorum.should_commit:
        average()
        outer_optimizer.step()
I am probably misusing some terms and APIs here, and I'm not sure having separate loops is desirable, but I am wondering if we have building blocks that let users write both DiLoCo and LocalSGD in about that many lines of code themselves?
If we do, then building the DiLoCo/LocalSGD helpers on top seems nice too.
These are good points! I don't actually think we need the quorum for each forward step, so I removed the pre-forward hook. The only downside to not computing the quorum each step is delayed failure detection. This would be a problem if dropping some replica groups hurt convergence, but according to the DiLoCo paper they see only a small drop in perplexity compared to a perfect-communication experiment.
I like the more general building-block pieces you mentioned. Currently Manager acts as the component that manages the quorums and what is happening across replica groups, so in that same vein I think it makes sense to enhance it to also run custom code. In doing so, I think we can achieve something similar to what you wrote above and give users better flexibility in writing their own stuff.
@H-Huang just fyi, ghstack doesn't work with TorchFT. You won't be able to land it later.
Why doesn't ghstack work? I thought it didn't need special enablement.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
# Set the .grad field of each parameter to its pseudogradient
for name, p in self._model.named_parameters():
    assert name in self._backup_parameters
    pseudogradient = p.data - self._backup_parameters[name]
Do we need to check which device the backup_parameters are on? We need a .to call here if they're on CPU right?
Oh yeah, good point. I will change up the backup_parameters and the load_state_dict in a follow-up PR.
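A hedged sketch of that follow-up fix; the helper name is hypothetical and this is not the final implementation:

import torch
import torch.nn as nn

def set_pseudogradients(model: nn.Module, backup_parameters: dict) -> None:
    for name, p in model.named_parameters():
        assert name in backup_parameters
        # Backup parameters may live on CPU to save accelerator memory, so move
        # them to the parameter's device before taking the difference.
        backup = backup_parameters[name].to(p.device)
        p.grad = p.data - backup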
torch.testing.assert_close(
    state_dict["outer_optim"],
    state_dicts[0][str_step]["outer_optim"],
)
should we be checking "model" params as well? or are they equivalent to backup_params at the end of the step?
Correct, they will be equivalent at the end of the step
Stack from ghstack (oldest at bottom):

API Usage

Changes
- Changes LocalSGD to be a context manager rather than an nn.Module wrapper.

discussion doc: https://docs.google.com/document/d/11c5JwQpSzilrDvK-vNsgQhpXAihbMn-hTRC8y3LiGqY/edit?tab=t.0#heading=h.izo4yi6jz4mk