Description
Installing `jax[cuda12]==0.6.1` leads to a `RuntimeError` because the cuSPARSE library cannot be found. JAX then falls back to CPU execution, even though the CUDA plugin packages (`jax-cuda12-plugin`, `jax-cuda12-pjrt`) and `nvidia-cusparse-cu12` are installed.
Here is how to reproduce:
```
$ uv venv
Using CPython 3.13.3
Creating virtual environment at: .venv
Activate with: source .venv/bin/activate
$ source .venv/bin/activate
(.venv) $ uv pip install -U "jax[cuda12]"
Using Python 3.13.3 environment at: .venv
Resolved 19 packages in 304ms
Prepared 19 packages in 37ms
Installed 19 packages in 6.00s
+ jax==0.6.1
+ jax-cuda12-pjrt==0.6.1
+ jax-cuda12-plugin==0.6.1
+ jaxlib==0.6.1
+ ml-dtypes==0.5.1
+ numpy==2.2.6
+ nvidia-cublas-cu12==12.8.4.1
+ nvidia-cuda-cupti-cu12==12.9.19
+ nvidia-cuda-nvcc-cu12==12.9.41
+ nvidia-cuda-runtime-cu12==12.9.37
+ nvidia-cudnn-cu12==9.10.1.4
+ nvidia-cufft-cu12==11.4.0.6
+ nvidia-cusolver-cu12==11.7.4.40
+ nvidia-cusparse-cu12==12.5.9.5
+ nvidia-nccl-cu12==2.26.5
+ nvidia-nvjitlink-cu12==12.9.41
+ nvidia-nvshmem-cu12==3.2.5
+ opt-einsum==3.4.0
+ scipy==1.15.3
(.venv) $ python
Python 3.13.3 (main, Apr 9 2025, 04:03:52) [Clang 20.1.0 ] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> jax.numpy.zeros(3)
ERROR:2025-05-22 12:51:37,974:jax._src.xla_bridge:444: Jax plugin configuration error: Exception when calling jax_plugins.xla_cuda12.initialize()
Traceback (most recent call last):
  File ".venv/lib/python3.13/site-packages/jax_plugins/xla_cuda12/__init__.py", line 135, in _version_check
    version = get_version()
RuntimeError: jaxlib/cuda/versions_helpers.cc:81: operation cusparseGetProperty(MAJOR_VERSION, &major) failed: The cuSPARSE library was not found.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File ".venv/lib/python3.13/site-packages/jax/_src/xla_bridge.py", line 442, in discover_pjrt_plugins
    plugin_module.initialize()
    ~~~~~~~~~~~~~~~~~~~~~~~~^^
  File ".venv/lib/python3.13/site-packages/jax_plugins/xla_cuda12/__init__.py", line 230, in initialize
    _check_cuda_versions(raise_on_first_error=True)
    ~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.13/site-packages/jax_plugins/xla_cuda12/__init__.py", line 199, in _check_cuda_versions
    _version_check("cuSPARSE", cuda_versions.cusparse_get_version,
    ~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                   cuda_versions.cusparse_build_version,
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                   # Ignore patch versions.
                   ^^^^^^^^^^^^^^^^^^^^^^^^
                   scale_for_comparison=100,
                   ^^^^^^^^^^^^^^^^^^^^^^^^^
                   min_supported_version=12100)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.13/site-packages/jax_plugins/xla_cuda12/__init__.py", line 139, in _version_check
    raise RuntimeError(err_msg) from e
RuntimeError: Unable to load cuSPARSE. Is it installed?
WARNING:2025-05-22 12:51:38,032:jax._src.xla_bridge:791: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
Array([0., 0., 0.], dtype=float32)
```
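The traceback only says cuSPARSE "was not found", even though `nvidia-cusparse-cu12==12.5.9.5` is installed. To narrow this down, one option is to locate the wheel's shared library and `dlopen` it directly. The dynamic loader then reports its own error message: a missing dependency, an undefined symbol, or a missing file. The sketch below assumes the usual NVIDIA wheel layout `site-packages/nvidia/<component>/lib/`. The helpers `find_nvidia_libs` and `try_load` are hypothetical and not part of JAX.

```python
import ctypes
import glob
import os
import sys
from typing import Optional


def find_nvidia_libs(pattern: str) -> list[str]:
    """Return shared-library paths matching `pattern` under every sys.path entry.

    Assumes (hypothetically) the pip wheel layout nvidia/<component>/lib/<lib>.
    """
    hits: set[str] = set()
    for entry in sys.path:
        if not entry:  # "" stands for the current directory; skip it
            continue
        hits.update(glob.glob(os.path.join(entry, "nvidia", "*", "lib", pattern)))
    return sorted(hits)


def try_load(path: str) -> Optional[str]:
    """dlopen `path`; return None on success, or the loader's error text on failure."""
    try:
        ctypes.CDLL(path)
    except OSError as exc:
        return str(exc)
    return None


if __name__ == "__main__":
    libs = find_nvidia_libs("libcusparse.so*")
    if not libs:
        print("no libcusparse found under any sys.path entry")
    for lib in libs:
        error = try_load(lib)
        print(f"{lib}: {'OK' if error is None else error}")
```

Run it with the same `.venv` interpreter that produced the error, so that `sys.path` points at the same `site-packages`.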
System info (python version, jaxlib version, accelerator, etc.)
```
>>> import jax; jax.print_environment_info()
jax:    0.6.1
jaxlib: 0.6.1
numpy:  2.2.6
python: 3.13.3 (main, Apr 9 2025, 04:03:52) [Clang 20.1.0 ]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='***', release='5.15.0-1018-gcp-tcpx', version='#18-Ubuntu SMP Fri Jul 26 14:21:24 UTC 2024', machine='x86_64')
```
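The environment info above lists `jax` and `jaxlib`, but not the CUDA plugin or the NVIDIA component wheels. The following minimal sketch prints those versions from the active environment using the standard-library `importlib.metadata`. The package list is taken from the install log, and `installed_versions` is a hypothetical helper.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

# Distribution names copied from the `uv pip install` log in this report.
PACKAGES = [
    "jax",
    "jaxlib",
    "jax-cuda12-plugin",
    "jax-cuda12-pjrt",
    "nvidia-cusparse-cu12",
    "nvidia-nvjitlink-cu12",
    "nvidia-cuda-runtime-cu12",
]


def installed_versions(names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is absent."""
    result: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions(PACKAGES).items():
        print(f"{name:28} {ver or 'not installed'}")
```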
```
$ nvidia-smi
Thu May 22 12:57:55 2025
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H100 80GB HBM3          On  | 00000000:04:00.0 Off |                    0 |
| N/A   42C    P0              73W / 700W |      4MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+
```
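`nvidia-smi` reports that driver 535.104.05 supports CUDA 12.2. The pip wheels from the install log target CUDA 12.9 (`nvidia-cuda-runtime-cu12==12.9.37`). The sketch below puts those two numbers side by side by parsing the `CUDA Version` field from the `nvidia-smi` header. It uses the runtime wheel because component wheels such as `nvidia-cusparse-cu12==12.5.9.5` carry their own library version numbers. CUDA's minor-version compatibility often lets newer 12.x libraries run on an older 12.x driver, so a gap here is a lead to check rather than a confirmed cause. The function names are hypothetical.

```python
import re
from typing import Optional, Tuple

# Header line copied (whitespace trimmed) from the nvidia-smi output in this report.
SMI_HEADER = "| NVIDIA-SMI 535.104.05   Driver Version: 535.104.05   CUDA Version: 12.2 |"


def driver_cuda_version(smi_text: str) -> Optional[Tuple[int, int]]:
    """Extract the (major, minor) 'CUDA Version' that nvidia-smi reports for the driver."""
    match = re.search(r"CUDA Version:\s*(\d+)\.(\d+)", smi_text)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def wheel_cuda_version(wheel_version: str) -> Tuple[int, int]:
    """Take (major, minor) from a wheel version string such as '12.9.37'."""
    major, minor = wheel_version.split(".")[:2]
    return int(major), int(minor)


if __name__ == "__main__":
    driver = driver_cuda_version(SMI_HEADER)
    runtime = wheel_cuda_version("12.9.37")  # nvidia-cuda-runtime-cu12 from the install log
    print(f"driver reports CUDA {driver}, runtime wheel targets CUDA {runtime}")
    if driver is not None and runtime > driver:
        print("runtime wheel is newer than the driver's CUDA version")
```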