Skip to content

Commit 065fb58

Browse files
authored
Remove PyTorch and Pyro-PPL from workflows (#689)
* Remove PyTorch and Pyro-PPL from workflows - Remove torch, torchvision, torchaudio installation from all workflows - Remove pyro-ppl installation from all workflows - Update step names from 'Install JAX, Numpyro, PyTorch' to 'Install JAX and Numpyro' - Disable build cache in ci.yml for full test run - Keep JAX and NumPyro installations * Fix JAX CUDA installation to use cuda12_pip instead of cuda12-local Changes jax[cuda12-local] to jax[cuda12_pip] to avoid cuDNN version compatibility issues. The cuda12_pip variant includes compatible CUDA and cuDNN libraries bundled with JAX, preventing runtime errors from mismatched local CUDA installations. Fixes cuDNN 9.10.0 backward-compatibility error. * Update RunsOn AMI to latest CUDA/cuDNN image Updates the quantecon_ubuntu2404 AMI from ami-09baf66e396fa7cfd to ami-0edec81935264b6d3 which includes the latest CUDA and cuDNN libraries for improved compatibility with JAX. * Use cuda12-local for JAX installation across all workflows Changes all workflows to use jax[cuda12-local] to leverage the CUDA and cuDNN libraries pre-installed in the new AMI (ami-0edec81935264b6d3). This is faster than cuda12_pip and uses the system libraries. Also removes version pin (==0.6.2) from cache.yml and publish.yml to make all workflows consistent. * Update JAX to use cuda13-local for CUDA 13 support Changes all workflows from cuda12-local to cuda13-local to match the CUDA 13 installation in the new AMI (ami-0edec81935264b6d3). * Update JAX installation to use cuda13 (recommended syntax) Changes from 'jax[cuda13-local]' to 'jax[cuda13]' following the official JAX documentation for CUDA 13 support. Also uses -U flag instead of --upgrade for consistency with JAX docs. * Move nvidia-smi check before JAX installation in ci.yml * Update runs-on.yml with region and use AMI ID directly in ci.yml * Rename image to quantecon_lecture_build to avoid cache issues * Revert image name back to quantecon_ubuntu2404 * Re-enable build cache * Revert instance name back to quantecon_ubuntu2404
1 parent de4fd8b commit 065fb58

File tree

3 files changed

+12
-16
lines changed

3 files changed

+12
-16
lines changed

.github/workflows/cache.yml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,11 @@ jobs:
2020
python-version: "3.13"
2121
environment-file: environment.yml
2222
activate-environment: quantecon
23-
- name: Install JAX, Numpyro, PyTorch
23+
- name: Install JAX and Numpyro
2424
shell: bash -l {0}
2525
run: |
26-
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
27-
pip install --upgrade "jax[cuda12-local]==0.6.2"
28-
pip install numpyro
26+
pip install -U "jax[cuda13]"
27+
pip install numpyro
2928
python scripts/test-jax-install.py
3029
- name: Check nvidia drivers
3130
shell: bash -l {0}

.github/workflows/ci.yml

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,15 @@ jobs:
2828
python-version: "3.13"
2929
environment-file: environment.yml
3030
activate-environment: quantecon
31-
- name: Install JAX, Numpyro, PyTorch
32-
shell: bash -l {0}
33-
run: |
34-
# pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
35-
# pip install pyro-ppl
36-
pip install "jax[cuda12-local]==0.6.2"
37-
pip install numpyro pyro-ppl
38-
python scripts/test-jax-install.py
3931
- name: Check nvidia Drivers
4032
shell: bash -l {0}
4133
run: nvidia-smi
34+
- name: Install JAX and Numpyro
35+
shell: bash -l {0}
36+
run: |
37+
pip install -U "jax[cuda13]"
38+
pip install numpyro
39+
python scripts/test-jax-install.py
4240
- name: Display Conda Environment Versions
4341
shell: bash -l {0}
4442
run: conda list

.github/workflows/publish.yml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,11 @@ jobs:
1919
python-version: "3.13"
2020
environment-file: environment.yml
2121
activate-environment: quantecon
22-
- name: Install JAX, Numpyro, PyTorch
22+
- name: Install JAX and Numpyro
2323
shell: bash -l {0}
2424
run: |
25-
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
26-
pip install --upgrade "jax[cuda12-local]==0.6.2"
27-
pip install numpyro
25+
pip install -U "jax[cuda13]"
26+
pip install numpyro
2827
python scripts/test-jax-install.py
2928
- name: Check nvidia drivers
3029
shell: bash -l {0}

0 commit comments

Comments
 (0)