[Distributed] Add support for torchchat checkpoint format #1268
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1268
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 0565e8b with merge base b217158.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Can you figure out from what's in the location whether it's one or the other format, rather than adding another option? (If push comes to shove, you could expect the checkpoint path to point to the index file?) Also, does this compose w/ #1255?
@mikekg Good idea. Should be easy to implement the detection. |
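A hedged sketch of what that detection could look like (the helper name and the exact index-file names are assumptions, not this PR's code): HF sharded checkpoints ship an index file mapping parameter FQNs to shard files, while torchchat checkpoints do not, so the presence of that file can select the format.

```python
import os

# Illustrative helper (hypothetical name): pick the format from what is
# actually on disk instead of adding another CLI option.
def detect_checkpoint_format(checkpoint_dir: str) -> str:
    hf_index_files = (
        "model.safetensors.index.json",
        "pytorch_model.bin.index.json",
    )
    # HF sharded checkpoints carry an index file; torchchat's do not.
    for name in hf_index_files:
        if os.path.isfile(os.path.join(checkpoint_dir, name)):
            return "hf"
    return "torchchat"
```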
distribution: str,
device: torch.device,
model_config: ModelArgs,
chpt_from: str,
For better clarity, 'chkpt' is a much better abbreviation for checkpoint; 'chpt' reads as short for 'chapter', which is confusing.
- For HF format, `new_to_old_keymap` is a mapping from the new key to the old
  key.
- For torchchat format, `new_to_old_keymap` is None (because FQN conversion
  has been doen by torchchat download script).
doen = done :)
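To make the two cases concrete, here is a hedged sketch of how a loader might apply such a keymap (the `remap_keys` helper is illustrative, not this PR's code):

```python
from typing import Dict, Optional

import torch

def remap_keys(
    checkpoint: Dict[str, torch.Tensor],
    new_to_old_keymap: Optional[Dict[str, str]],
) -> Dict[str, torch.Tensor]:
    # torchchat format: FQNs were already converted at download time,
    # so the checkpoint can be used as-is.
    if new_to_old_keymap is None:
        return checkpoint
    # HF format: translate each model FQN (new key) to the HF name
    # (old key) to look up the tensor.
    return {new: checkpoint[old] for new, old in new_to_old_keymap.items()}
```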
updated_states: Set[str] = set()
# This step converts full tensor into DTensor
update_state_dict(
Do we need to pass in whether this is torchchat or HF, to decide whether to permute or not?
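One way to thread that through (a hedged sketch; the guard helper is hypothetical and only illustrates the check the reviewer is asking for):

```python
def needs_rope_permute(param_name: str, chpt_from: str) -> bool:
    # Only HF checkpoints store the attention wq/wk weights in the
    # permuted rotary-embedding layout; torchchat checkpoints are
    # already in the model's native layout and must be copied as-is.
    return chpt_from == "hf" and ("wq" in param_name or "wk" in param_name)
```

`update_state_dict` could then take `chpt_from` (or a precomputed boolean) alongside `new_to_old_keymap` and consult this guard before permuting.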
checkpoint_tensor = checkpoint[old_param]
model_tensor = state_dict[param]

if "wq" in param:
Need to check what you are loading here: torchchat checkpoints won't need permuting.
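For reference, a hedged sketch of a format-aware version of this branch. The permutation below follows the form used by HF's Llama conversion script; whether this codebase uses exactly that form is an assumption, and `maybe_permute` is a hypothetical wrapper:

```python
import torch

def permute_rope(w: torch.Tensor, n_heads: int) -> torch.Tensor:
    # Reorder the interleaved rotary halves of a wq/wk weight, in the
    # form used by HF's Llama conversion script.
    dim1, dim2 = w.shape
    return (
        w.view(n_heads, dim1 // n_heads // 2, 2, dim2)
        .transpose(1, 2)
        .reshape(dim1, dim2)
    )

def maybe_permute(
    checkpoint_tensor: torch.Tensor, param: str, chpt_from: str, n_heads: int
) -> torch.Tensor:
    # Hypothetical wrapper: permute only when loading an HF checkpoint;
    # a torchchat checkpoint is already in the expected layout.
    if chpt_from == "hf" and ("wq" in param or "wk" in param):
        return permute_rope(checkpoint_tensor, n_heads)
    return checkpoint_tensor
```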
Approving to land.
I think your code here doesn't work b/c you are not checking what type of state dict you are updating. Thus it seems you are going to permute a torchchat checkpoint, which doesn't need it.
That may be why: "But they generate different responses. To be investigated further."
Also, please switch the abbreviation for checkpoint to 'chkpt' and not 'chpt'. 'chpt' to me = 'chapter' and reads confusingly :)
The distributed workflow now supports two checkpoint formats, controlled by the `--chpt-from` flag ("hf" or "torchchat"). The difference is whether there is an index file, which impacts loading in the pipeline-parallel case.
torchrun --standalone --nproc-per-node 4 dist_run.py llama3 --pp 2 --chpt-from hf
torchrun --standalone --nproc-per-node 4 dist_run.py llama3 --pp 2 --chpt-from torchchat
Both run, but they generate different responses. To be investigated further.
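To illustrate why the index file matters in the pipeline-parallel case (a hedged sketch: the HF index file and its `weight_map` field are real, but the stage-filtering helper is an assumed illustration, not this PR's code): with an index, each pipeline stage can open only the shard files that hold its own parameters instead of reading the full checkpoint.

```python
import json
import os
from typing import Set

def shards_for_stage(checkpoint_dir: str, stage_params: Set[str]) -> Set[str]:
    # The HF index file maps each parameter FQN to the shard file that
    # contains it, so a pipeline stage can load just its own shards.
    index_path = os.path.join(checkpoint_dir, "model.safetensors.index.json")
    with open(index_path) as f:
        weight_map = json.load(f)["weight_map"]
    return {weight_map[p] for p in stage_params if p in weight_map}
```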
TODO: we should create the real model instead of constructing a Transformer directly.