Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

⚡️ Overall improvements, bug fixes, more unit tests, and dm-haiku compatibility tested #6

Merged
merged 18 commits into from
Dec 26, 2022

Conversation

alvarobartt
Copy link
Owner

@alvarobartt alvarobartt commented Dec 26, 2022

✨ Features

  • Split flax.py into save.py, load.py, and utils.py for readability
    • save.py contains serialize
    • load.py contains deserialize
    • utils.py contains both flatten_dict and unflatten_dict
  • Add freeze_dict param to unflatten_dict to either convert it to FrozenDict or keep it as a Dict (used for flax)
  • Update unit tests with pytest to cover every safejax function
  • Test dm-haiku model param serialization over haiku.nets.ResNet50
  • Add more examples/ for both flax and dm-haiku

🐛 Bug Fixes

  • Fix bug while unflattening dictionaries in unflatten_dict due to a variable being overwritten

🧪 Tests

  • Did you implement unit tests if required?

If the above checkbox is checked, describe how you unit-tested it.

  • Add some assertions to make sure both safejax.utils.flatten_dict and safejax.utils.unflatten_dict work as expected to avoid bug mentioned above with unflatten_dict
  • Add some more unit tests for safejax.load and safejax.save due to the recent split of both files

@alvarobartt alvarobartt self-assigned this Dec 26, 2022
@alvarobartt alvarobartt merged commit aaee170 into main Dec 26, 2022
@alvarobartt alvarobartt deleted the extend-usage branch December 26, 2022 10:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant