Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update jax, brax and flax versions (fixes the jax.tree_util warnings) #76

Merged
merged 15 commits into from
Oct 6, 2022

Conversation

Lookatator
Copy link
Member

@Lookatator Lookatator commented Aug 25, 2022

Related to issue #74

  • all jax tree-based functions are now imported from the tree_util module
  • updates the dependencies in requirements.txt and setup.py
  • Update the version of Brax in the requirements and in the setup.py

This PR also adds the wrapper CompletedEvalWrapper that used to be in Brax, but has been removed in the most recent versions.

In summary, this wrapper used to be present in Brax, and we rely on some elements of it in 4 algorithms (I think those are: SAC, DIAYN, DADS and TD3), and their tests where all failing for that reason.

So what I did is copying the old wrapper of Brax here, and now all the tests pass.

This is a provisional solution to make it work in the same way as before.
Maybe we can improve and adapt our own code to the new structure of QDax, but I think that is for another time/PR.

@Lookatator Lookatator added the enhancement New feature or request label Aug 25, 2022
@Lookatator Lookatator marked this pull request as ready for review August 25, 2022 09:44
@Lookatator Lookatator linked an issue Aug 25, 2022 that may be closed by this pull request
@limbryan
Copy link
Collaborator

We should also update flax and chex version to fix the warnings completely - Flax (0.6.0) and chex (0.1.4) versions have these changes updated.

@Lookatator
Copy link
Member Author

We should also update flax and chex version to fix the warnings completely - Flax (0.6.0) and chex (0.1.4) versions have these changes updated.

That was included in the PR.

However, I've just changed it a bit. Because:

  • flax == 0.6.0 requires jax >= 0.3.16
  • jax >= 0.3.16 does not have tree_multimap
  • all versions of brax do not have tree_multimap

QDax/setup.py

Lines 27 to 29 in 8aa7235

# if flax>=0.6, then requires jax==0.3.16,
# which is incompatible with all versions of brax
"flax>=0.5.0,<0.6",

@codecov-commenter
Copy link

codecov-commenter commented Aug 25, 2022

Codecov Report

Merging #76 (92fade5) into main (67aa13d) will decrease coverage by 0.00%.
The diff coverage is 96.87%.

❗ Current head 92fade5 differs from pull request most recent head 8168865. Consider uploading reports for the commit 8168865 to get more accurate results

@@            Coverage Diff             @@
##             main      #76      +/-   ##
==========================================
- Coverage   89.49%   89.49%   -0.01%     
==========================================
  Files          66       66              
  Lines        3864     3863       -1     
==========================================
- Hits         3458     3457       -1     
  Misses        406      406              
Impacted Files Coverage Δ
qdax/core/neuroevolution/mdp_utils.py 90.62% <50.00%> (ø)
qdax/core/containers/mome_repertoire.py 98.11% <91.66%> (ø)
qdax/baselines/dads.py 97.00% <100.00%> (ø)
qdax/baselines/diayn.py 93.07% <100.00%> (ø)
qdax/baselines/sac.py 93.57% <100.00%> (ø)
qdax/baselines/td3.py 100.00% <100.00%> (ø)
qdax/core/containers/ga_repertoire.py 80.39% <100.00%> (ø)
qdax/core/containers/mapelites_repertoire.py 84.78% <100.00%> (ø)
qdax/core/containers/nsga2_repertoire.py 98.52% <100.00%> (ø)
qdax/core/containers/spea2_repertoire.py 100.00% <100.00%> (ø)
... and 6 more

📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more

@felixchalumeau
Copy link
Collaborator

felixchalumeau commented Sep 1, 2022

Hey 👋

Have you made a few double checks to be sure that those updates won't affect performances?

@felixchalumeau
Copy link
Collaborator

Hi guys,

I reviewed the PR and am happy with it except for the CompletedEvalWrapper.
I want to ask you more questions about it and also, does it has to go with this PR in term of compatibility or can it be introduced in a separate pull request?

@felixchalumeau felixchalumeau changed the title Fix: jax tree utils warnings Update jax, brax and flax versions (fixes the jax.tree_util warnings) Oct 3, 2022
@Lookatator Lookatator merged commit 2fb5619 into main Oct 6, 2022
@limbryan limbryan deleted the fix/jax-tree-utils-warnings branch November 30, 2022 15:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

fix the Jax FutureWarnings
4 participants