-
Notifications
You must be signed in to change notification settings - Fork 555
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Serialization] Add is_main_process
argument to save_torch_state_dict()
#2648
[Serialization] Add is_main_process
argument to save_torch_state_dict()
#2648
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
@@ -472,3 +472,27 @@ def test_save_torch_state_dict_delete_existing_files( | |||
assert (tmp_path / "pytorch_model-00001-of-00003.bin").is_file() | |||
assert (tmp_path / "pytorch_model-00002-of-00003.bin").is_file() | |||
assert (tmp_path / "pytorch_model-00003-of-00003.bin").is_file() | |||
|
|||
|
|||
def test_save_torch_state_dict_not_main_process( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice test!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good!
Once this is released in huggingface_hub==0.27.0, PRs will be opened to update accelerate's save_model(), transformers' save_pretrained() and diffusers' save_pretrained() to use save_torch_state_dict() directly.
Super exciting!
except Exception as e: | ||
logger.warning(f"Error when trying to remove existing '{filename}' from folder: {e}. Continuing...") | ||
# Only main process should clean up existing files to avoid race conditions in distributed environment | ||
if is_main_process: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice ! Thanks for adding this !
Thanks for the reviews! |
This small PR adds
is_main_process
parameter tosave_torch_state_dict()
to prevent race conditions during distributed environment. This aligns with accelerate's, transformers' and diffusers' implementations and will enable standardization of model saving across these libraries. See #2314 and this internal slack message for more context.Once this is released in huggingface_hub==0.27.0, PRs will be opened to update accelerate's save_model(), transformers' save_pretrained() and diffusers' save_pretrained() to use
save_torch_state_dict()
directly.Main changes:
(Following existing implementations in accelerate, transformers and diffusers)
is_main_process=True
to avoid race conditions during distributed environment.is_main_process=True
as default parameter.cc @muellerzr, @SunMarc and @sayakpaul for visibility.