Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

create mininal universal checkpoint info for client state #5526

Closed
wants to merge 2 commits into from

Conversation

xylian86
Copy link
Contributor

This PR solves the Issue-5430.

The PR enables the universal checkpoint feature for other platforms like HuggingFace Trainer without requiring changes to the HuggingFace code. It does this by creating a minimal universal checkpoint info, specifically the version, as a default action for the client state.

@tjruwase tjruwase requested review from samadejacobs, tohtana and lekurile and removed request for mrwyattii May 13, 2024 09:38
@tjruwase
Copy link
Contributor

@xylian86, thanks for this great work. Can you please add convergence curves of an HF model as demo?

@@ -3319,6 +3320,7 @@ def _save_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_parame
ds_config=self.config,
ds_version=version)
state.update(client_state)
inject_universal_info(state)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shoudn't we show a warning when we don't have the necessary info?
It will silently produce an incorrect checkpoint if the checkpoint is loaded for TP or PP.
We can say that the converted checkpoint is only for pure DP scaling.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After I discussed with Tunji, we are considering another approach for this. He will share a new approach. I keep this comment but just disregard it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@xylian86, the new approach is to do the injection in the conversion script rather than during saving. The injection should be done into ds_checkpoint before this assertion. Furthermore, the injection should be enabled by command-line argument (disabled by default) so that users are fully aware of what is going on. The command-line arg could be called --inject-missing-state.

@@ -5,7 +5,7 @@

import os
import torch
from .constants import (MODEL_FILE_PREFIX, MODEL_FILE_SUFFIX, OPTIM_FILE_SUFFIX, ZERO_FILE_PREFIX)
from .constants import (MODEL_FILE_PREFIX, MODEL_FILE_SUFFIX, OPTIM_FILE_SUFFIX, ZERO_FILE_PREFIX, UNIVERSAL_CHECKPOINT_INFO, UNIVERSAL_CHECKPOINT_VERSION_KEY, UNIVERSAL_CHECKPOINT_VERSION_VALUE)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI @xylian86 - can you run the precommit formatter on this branch so it will pass our Formatting check?

pre-commit run --all-files

@xylian86
Copy link
Contributor Author

xylian86 commented Jun 3, 2024

Close this PR as I opened a new one at PR-5608 with the new implementation as @tjruwase suggested.

@xylian86 xylian86 closed this Jun 3, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants