Skip to content

Commit 64ac43a

Browse files
authored
[CI] Bump Flax and Jaxlib versions to fix Jaxlib install error (#15421)
Bump Flax and Jax versions to fix install error The Flax dependency Orbax (v0.1.8) has deprecated being able to install Orbax as a standalone package. Flax v0.6.8 attempts to install Orbax as a standalone package and raises an error about doing so. Going forward, the package orbax-checkpoint should be installed instead. Flax v0.6.8 does not recognize this and attempts to install Orbax instead of orbax-checkpoint and the installation fails. In order to resolve Jax installation issues, bumping the version of Flax to be at least 0.6.9, which resolves the problem. Flax >= 0.6.9 does not pin the version of orbax-checkpoint that it installs and the latest version requires Jax >= 0.4.9 to be installed so the two must be updated together.
1 parent 4a55820 commit 64ac43a

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

docker/install/ubuntu_install_jax.sh

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@ set -o pipefail
2323
# Install jax and jaxlib
2424
if [ "$1" == "cuda" ]; then
2525
pip3 install --upgrade \
26-
jaxlib==0.4.7 \
27-
"jax[cuda11_pip]==0.4.7" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
26+
jaxlib~=0.4.9 \
27+
"jax[cuda11_pip]~=0.4.9" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
2828
else
2929
pip3 install --upgrade \
30-
jaxlib==0.4.7 \
31-
"jax[cpu]==0.4.7"
30+
jaxlib~=0.4.9 \
31+
"jax[cpu]~=0.4.9"
3232
fi
3333

3434
# Install flax
35-
pip3 install flax==0.6.8
35+
pip3 install flax~=0.6.9

0 commit comments

Comments
 (0)