Skip to content

Commit e4f3f8f

Browse files
hawkinspGoogle-ML-Automation
authored andcommitted
Use libtpu releases rather than libtpu-nightly for jax[tpu].
PiperOrigin-RevId: 688632409
1 parent 1c6b0a9 commit e4f3f8f

File tree

4 files changed

+14
-6
lines changed

4 files changed

+14
-6
lines changed

.github/workflows/cloud-tpu-ci-nightly.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,23 +50,23 @@ jobs:
5050
pip install -U -r build/collect-profile-requirements.txt
5151
- name: Install JAX
5252
run: |
53-
pip uninstall -y jax jaxlib libtpu-nightly
53+
pip uninstall -y jax jaxlib libtpu
5454
if [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then
5555
pip install .[tpu] \
5656
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
5757
5858
elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then
5959
pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
60-
pip install --pre libtpu-nightly \
60+
pip install --pre libtpu \
6161
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
6262
pip install requests
6363
6464
elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then
6565
pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
66+
# TODO(phawkins): switch to libtpu, when the oldest release we support is a libtpu release.
6667
pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
6768
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
6869
pip install requests
69-
7070
else
7171
echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}"
7272
exit 1

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
2424
* `jax.lax.FftType` was introduced as a public name for the enum of FFT
2525
operations. The semi-public API `jax.lib.xla_client.FftType` has been
2626
deprecated.
27+
* TPU: JAX now installs TPU support from the `libtpu` package rather than
28+
`libtpu-nightly`. For the next few releases JAX will pin an empty version of
29+
`libtpu-nightly` as well as `libtpu` to ease the transition; that dependency
30+
will be removed in Q1 2025.
2731

2832
* Deprecations:
2933
* The semi-public API `jax.lib.xla_client.PaddingType` has been deprecated.

docs/installation.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ pip install -U --pre jax jaxlib -f https://storage.googleapis.com/jax-releases/j
282282
- Google Cloud TPU:
283283

284284
```bash
285-
pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
285+
pip install -U --pre jax jaxlib libtpu requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
286286
```
287287

288288
- NVIDIA GPU (CUDA 12):

setup.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222
_current_jaxlib_version = '0.4.34'
2323
# The following should be updated after each new jaxlib release.
2424
_latest_jaxlib_version_on_pypi = '0.4.34'
25-
_libtpu_version = '0.1.dev20241002'
25+
26+
_libtpu_version = '0.0.2'
27+
_libtpu_nightly_terminal_version = '0.1.dev20241010+nightly.cleanup'
2628

2729
def load_version_module(pkg_path):
2830
spec = importlib.util.spec_from_file_location(
@@ -76,7 +78,9 @@ def load_version_module(pkg_path):
7678
# $ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
7779
'tpu': [
7880
f'jaxlib>={_current_jaxlib_version},<={_jax_version}',
79-
f'libtpu-nightly=={_libtpu_version}',
81+
# TODO(phawkins): remove the libtpu-nightly dependency in Q1 2025.
82+
f'libtpu-nightly=={_libtpu_nightly_terminal_version}',
83+
f'libtpu=={_libtpu_version}',
8084
'requests', # necessary for jax.distributed.initialize
8185
],
8286

0 commit comments

Comments
 (0)