Skip to content

Improve save_model and load_model #19852

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 15, 2024

Conversation

james77777778
Copy link
Contributor

@james77777778 james77777778 commented Jun 14, 2024

Fixes #19793

  • save_model will first open the .h5 (in the same working dir as filepath), then write it to the zipfile at the end.
  • load_model will first extract the .h5, then load it using h5py.File. The weights file will be deleted at the end.

Result:
Using GPT2 from KerasNLP

Branch Function Memory Usage Time Cost
Master save_model 1860.43MiB 3.84s
PR save_model 1384.17MiB (-25.6%) 4.48s (+16.7%)
Master load_model 1377.24MiB 26.93s
PR load_model 1376.43MiB (-0.1%) 2.10s (-92.3%)

The only regression is that save_model might be slightly slower than before.

@codecov-commenter
Copy link

codecov-commenter commented Jun 14, 2024

Codecov Report

Attention: Patch coverage is 78.72340% with 10 lines in your changes missing coverage. Please review.

Project coverage is 78.84%. Comparing base (11be99a) to head (99a8dcc).
Report is 1 commits behind head on master.

Files Patch % Lines
keras/src/saving/saving_lib.py 78.72% 6 Missing and 4 partials ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##           master   #19852   +/-   ##
=======================================
  Coverage   78.83%   78.84%           
=======================================
  Files         498      498           
  Lines       45864    45887   +23     
  Branches     8450     8455    +5     
=======================================
+ Hits        36159    36181   +22     
  Misses       7999     7999           
- Partials     1706     1707    +1     
Flag Coverage Δ
keras 78.70% <78.72%> (+<0.01%) ⬆️
keras-jax 62.39% <78.72%> (+0.01%) ⬆️
keras-numpy 56.96% <61.70%> (+0.01%) ⬆️
keras-tensorflow 63.68% <78.72%> (+0.01%) ⬆️
keras-torch 62.37% <78.72%> (+0.01%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@Grvzard
Copy link
Contributor

Grvzard commented Jun 14, 2024

a tiny suggestion: the form _VARS_FNAME + ".h5" appears many times. Would it be better to define a new constant like _VARS_FNAME_H5 = _VARS_FNAME + ".h5"?

@james77777778
Copy link
Contributor Author

a tiny suggestion: the form _VARS_FNAME + ".h5" appears many times. Would it be better to define a new constant like _VARS_FNAME_H5 = _VARS_FNAME + ".h5"?

That's cleaner. Thanks for the suggestion.
Updated.

Copy link
Collaborator

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix! This is great.

@@ -202,6 +227,8 @@ def _load_model_from_fileobj(fileobj, custom_objects, compile, safe_mode):
weights_store.close()
if asset_store:
asset_store.close()
if weights_file_path:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should put the unlink in a try/except/finally block to make sure it gets executed no matter what.

Copy link
Contributor Author

@james77777778 james77777778 Jun 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added the try/finally blocks.
In _save_model_to_fileobj, if an error is raised, the bad model.weights.h5 won't be written to the zip file.
In _load_model_from_fileobj, if an error is raised, the extracted model.weights.h5 will be deleted no matter what.

Tests have been added to verify these behaviors.

Improve `save_model` and `load_model`

Address comments
Copy link
Collaborator

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the contribution!

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Jun 15, 2024
@fchollet fchollet merged commit 9cf4e94 into keras-team:master Jun 15, 2024
@google-ml-butler google-ml-butler bot removed awaiting review ready to pull Ready to be merged into the codebase kokoro:force-run labels Jun 15, 2024
@james77777778 james77777778 deleted the update-saving branch June 16, 2024 10:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
Status: Merged
Development

Successfully merging this pull request may close these issues.

model.keras format much slower to load
5 participants