[Distributed] Add support for torchchat checkpoint format #1268


Merged: 6 commits merged into main on Oct 7, 2024
Conversation

@kwen2501 (Contributor) commented on Oct 4, 2024

The distributed workflow now supports two checkpoint formats, controlled by the --chpt-from flag ("hf" or "torchchat").
The difference is whether an index file is present, which affects how weights are loaded in the pipeline-parallel case.

```python
    if chpt_from == "hf":
        # This format consists of: an index file + multiple binary files

    elif chpt_from == "torchchat":
        # This format consists of:
        # a single binary file, OR
        # multiple binary files without an index file.
```
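
For illustration, here is a minimal sketch of how the two layouts differ at load time. The file name `pytorch_model.bin.index.json` follows the Hugging Face convention; the helper itself is hypothetical, not code from this PR:

```python
import json
import os
from typing import List

def resolve_checkpoint_files(ckpt_dir: str, chpt_from: str) -> List[str]:
    """Hypothetical helper: list the binary files a pipeline stage should load."""
    if chpt_from == "hf":
        # HF format: an index JSON maps each parameter FQN to the shard file
        # containing it, so a stage can open only the shards it actually needs.
        index_path = os.path.join(ckpt_dir, "pytorch_model.bin.index.json")
        with open(index_path) as f:
            weight_map = json.load(f)["weight_map"]
        return sorted(set(weight_map.values()))
    elif chpt_from == "torchchat":
        # torchchat format: a single binary file, or several binaries with no
        # index -- each stage must scan every file for its parameters.
        return sorted(f for f in os.listdir(ckpt_dir) if f.endswith((".pth", ".bin")))
    raise ValueError(f"unknown checkpoint format: {chpt_from}")
```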

```bash
torchrun --standalone --nproc-per-node 4 dist_run.py llama3 --pp 2 --chpt-from hf
torchrun --standalone --nproc-per-node 4 dist_run.py llama3 --pp 2 --chpt-from torchchat
```

Both commands run, but they generate different responses. To be investigated further.

TODO:
We should create the real model instead of constructing a Transformer directly.

pytorch-bot (bot) commented on Oct 4, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1268

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 0565e8b with merge base b217158:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@kwen2501 requested a review from lessw2020 on October 4, 2024, 19:51
@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on Oct 4, 2024
@mikekg commented on Oct 5, 2024

Can you figure out from what's in the location whether it's one format or the other, rather than adding another option? (If push comes to shove, you could require the checkpoint path to point to the index file.)

Also, does this compose with #1255?

@kwen2501 (Contributor, Author) commented on Oct 6, 2024

@mikekg Good idea. The detection should be easy to implement.
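
A minimal sketch of such detection, assuming the standard HF index file names and treating their absence as torchchat format:

```python
import os

def detect_checkpoint_format(ckpt_dir: str) -> str:
    """Infer "hf" vs "torchchat" from the checkpoint directory contents."""
    hf_index_names = (
        "pytorch_model.bin.index.json",
        "model.safetensors.index.json",
    )
    # HF multi-file checkpoints ship an index file; torchchat checkpoints
    # (single file, or multiple files without an index) do not.
    if any(os.path.isfile(os.path.join(ckpt_dir, name)) for name in hf_index_names):
        return "hf"
    return "torchchat"
```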

distribution: str,
device: torch.device,
model_config: ModelArgs,
chpt_from: str,
Review comment (Contributor):

For better clarity, 'chkpt' is a much better abbreviation for checkpoint; 'chpt' reads like an abbreviation for 'chapter', which is confusing.

- For HF format, `new_to_old_keymap` is a mapping from the new key to the old
key.
- For torchchat format, `new_to_old_keymap` is None (because FQN conversion
has been doen by torchchat download script).
Review comment (Contributor):

doen = done :)


updated_states: Set[str] = set()
# This step converts full tensor into DTensor
update_state_dict(
Review comment (Contributor):
Need to pass in whether this is torchchat or hf, to decide whether to permute?

checkpoint_tensor = checkpoint[old_param]
model_tensor = state_dict[param]

if "wq" in param:
Review comment (Contributor):
Need to check what you are loading here; torchchat won't need permuting.
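
A hedged sketch of the fix being requested: gate the rotary-embedding permutation on the checkpoint format. The `permute` below mirrors the inverse of the permutation applied by the HF Llama conversion script; `maybe_permute` and its parameters are illustrative, not this PR's code:

```python
import torch

def permute(w: torch.Tensor, n_heads: int, dim1: int, dim2: int) -> torch.Tensor:
    # Undo the head-dim interleaving that the HF Llama conversion script
    # applies to wq/wk (for GQA models, pass the number of KV heads for wk).
    return (
        w.view(n_heads, 2, dim1 // n_heads // 2, dim2)
        .transpose(1, 2)
        .reshape(dim1, dim2)
    )

def maybe_permute(param: str, t: torch.Tensor, chpt_from: str, n_heads: int) -> torch.Tensor:
    # Only HF checkpoints need the permutation; torchchat checkpoints are
    # already in the layout the model expects.
    if chpt_from == "hf" and ("wq" in param or "wk" in param):
        return permute(t, n_heads, t.size(0), t.size(1))
    return t
```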

@lessw2020 (Contributor) left a review:

Approving to land.
I think your code here doesn't work because you are not checking which type of state dict you are updating. It seems you are going to permute a torchchat checkpoint, which doesn't need it.
That may be why: "But they generate different responses. To be investigated further."

Also, please switch the abbreviation for checkpoint to "chkpt" and not "chpt". 'Chpt' to me = 'chapter' and reads confusingly :)

@kwen2501 kwen2501 merged commit dfec431 into main Oct 7, 2024
52 checks passed
Labels: CLA Signed (managed by the Meta Open Source bot)