Installing jax[cuda12]==0.6.1 leads to a RuntimeError because the cuSPARSE library cannot be found #28929

Closed
@maxencefaldor

Description

Installing jax[cuda12]==0.6.1 leads to a RuntimeError because the cuSPARSE library cannot be found. JAX then falls back to CPU execution, even though the CUDA plugin packages (jax-cuda12-plugin, jax-cuda12-pjrt) and nvidia-cusparse-cu12 are installed.
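Apart from a logged warning, the CPU fallback is silent, so jax.numpy.zeros(3) still returns an array. One way to turn it into a hard failure is to restrict JAX to the CUDA backend with the JAX_PLATFORMS environment variable, which has to be set before jax is first imported. A minimal sketch; the helper name jax_gpu_status is hypothetical:

```python
import os


def jax_gpu_status() -> str:
    """Report whether JAX came up on a GPU instead of silently accepting CPU."""
    # Restricting JAX to the CUDA platform makes backend initialisation raise
    # instead of falling back to CPU. This only has an effect if it is set
    # before jax is first imported in this process; an explicit user setting
    # (e.g. JAX_PLATFORMS=cpu) is left untouched by setdefault.
    os.environ.setdefault("JAX_PLATFORMS", "cuda")
    try:
        import jax
    except ImportError:
        return "jax is not installed"
    try:
        devices = jax.devices()
    except Exception as exc:  # backend-initialisation failures surface here
        return f"CUDA backend unavailable: {exc}"
    return f"{devices[0].platform}: {len(devices)} device(s)"


if __name__ == "__main__":
    print(jax_gpu_status())
```

In the environment from this report, the expectation is that this prints a "CUDA backend unavailable" message carrying the cuSPARSE error rather than reporting a CPU device.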

Here is how to reproduce:

$ uv venv
Using CPython 3.13.3
Creating virtual environment at: .venv
Activate with: source .venv/bin/activate

$ source .venv/bin/activate

(.venv) $ uv pip install -U "jax[cuda12]"
Using Python 3.13.3 environment at: .venv
Resolved 19 packages in 304ms
Prepared 19 packages in 37ms
Installed 19 packages in 6.00s
 + jax==0.6.1
 + jax-cuda12-pjrt==0.6.1
 + jax-cuda12-plugin==0.6.1
 + jaxlib==0.6.1
 + ml-dtypes==0.5.1
 + numpy==2.2.6
 + nvidia-cublas-cu12==12.8.4.1
 + nvidia-cuda-cupti-cu12==12.9.19
 + nvidia-cuda-nvcc-cu12==12.9.41
 + nvidia-cuda-runtime-cu12==12.9.37
 + nvidia-cudnn-cu12==9.10.1.4
 + nvidia-cufft-cu12==11.4.0.6
 + nvidia-cusolver-cu12==11.7.4.40
 + nvidia-cusparse-cu12==12.5.9.5
 + nvidia-nccl-cu12==2.26.5
 + nvidia-nvjitlink-cu12==12.9.41
 + nvidia-nvshmem-cu12==3.2.5
 + opt-einsum==3.4.0
 + scipy==1.15.3

(.venv) $ python
Python 3.13.3 (main, Apr  9 2025, 04:03:52) [Clang 20.1.0 ] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> jax.numpy.zeros(3)
ERROR:2025-05-22 12:51:37,974:jax._src.xla_bridge:444: Jax plugin configuration error: Exception when calling jax_plugins.xla_cuda12.initialize()
Traceback (most recent call last):
  File ".venv/lib/python3.13/site-packages/jax_plugins/xla_cuda12/__init__.py", line 135, in _version_check
    version = get_version()
RuntimeError: jaxlib/cuda/versions_helpers.cc:81: operation cusparseGetProperty(MAJOR_VERSION, &major) failed: The cuSPARSE library was not found.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File ".venv/lib/python3.13/site-packages/jax/_src/xla_bridge.py", line 442, in discover_pjrt_plugins
    plugin_module.initialize()
    ~~~~~~~~~~~~~~~~~~~~~~~~^^
  File ".venv/lib/python3.13/site-packages/jax_plugins/xla_cuda12/__init__.py", line 230, in initialize
    _check_cuda_versions(raise_on_first_error=True)
    ~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.13/site-packages/jax_plugins/xla_cuda12/__init__.py", line 199, in _check_cuda_versions
    _version_check("cuSPARSE", cuda_versions.cusparse_get_version,
    ~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                   cuda_versions.cusparse_build_version,
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                   # Ignore patch versions.
                   ^^^^^^^^^^^^^^^^^^^^^^^^
                   scale_for_comparison=100,
                   ^^^^^^^^^^^^^^^^^^^^^^^^^
                   min_supported_version=12100)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.13/site-packages/jax_plugins/xla_cuda12/__init__.py", line 139, in _version_check
    raise RuntimeError(err_msg) from e
RuntimeError: Unable to load cuSPARSE. Is it installed?
WARNING:2025-05-22 12:51:38,032:jax._src.xla_bridge:791: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
Array([0., 0., 0.], dtype=float32)
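The traceback shows the plugin's version check calling cusparseGetProperty(MAJOR_VERSION, &major) and failing because cuSPARSE could not be loaded. A stdlib-only sketch that repeats that step outside JAX can help tell whether the library is missing entirely or is present but fails to load (for example, because one of its own dependencies is not found). It assumes the nvidia-cusparse-cu12 wheel places libcusparse.so.12 under site-packages/nvidia/cusparse/lib/; the helper names are hypothetical, and this is not necessarily how jaxlib itself locates the library.

```python
import ctypes
import site
from pathlib import Path

# Values of CUDA's libraryPropertyType enum accepted by cusparseGetProperty.
MAJOR_VERSION, MINOR_VERSION, PATCH_LEVEL = 0, 1, 2


def find_wheel_cusparse() -> list[Path]:
    """Return copies of libcusparse.so.12 found inside pip-installed nvidia wheels.

    Assumes the wheel layout site-packages/nvidia/cusparse/lib/libcusparse.so.12.
    """
    roots = site.getsitepackages() + [site.getusersitepackages()]
    hits = []
    for root in roots:
        candidate = Path(root) / "nvidia" / "cusparse" / "lib" / "libcusparse.so.12"
        if candidate.exists():
            hits.append(candidate)
    return hits


def probe_cusparse(target: str = "libcusparse.so.12") -> str:
    """dlopen cuSPARSE and query its version, mirroring the failing plugin check."""
    try:
        lib = ctypes.CDLL(target)
    except OSError as exc:
        # The dynamic loader's message usually names the file it could not open.
        return f"load failed: {exc}"
    get_property = lib.cusparseGetProperty
    get_property.argtypes = [ctypes.c_int, ctypes.POINTER(ctypes.c_int)]
    get_property.restype = ctypes.c_int
    parts = []
    for prop in (MAJOR_VERSION, MINOR_VERSION, PATCH_LEVEL):
        value = ctypes.c_int()
        status = get_property(prop, ctypes.byref(value))
        if status != 0:  # 0 == CUSPARSE_STATUS_SUCCESS
            return f"cusparseGetProperty failed with status {status}"
        parts.append(str(value.value))
    return "loaded, version " + ".".join(parts)


if __name__ == "__main__":
    wheel_libs = [str(p) for p in find_wheel_cusparse()]
    print("wheel copies:", wheel_libs or "none")
    # Try each wheel copy by absolute path, then whatever the system loader finds.
    for target in wheel_libs + ["libcusparse.so.12"]:
        print(target, "->", probe_cusparse(target))
```

Loading by absolute path versus by bare soname separates "the file is not where the loader looks" from "the file is found but cannot be loaded".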

System info (python version, jaxlib version, accelerator, etc.)

>>> import jax; jax.print_environment_info()
jax:    0.6.1
jaxlib: 0.6.1
numpy:  2.2.6
python: 3.13.3 (main, Apr  9 2025, 04:03:52) [Clang 20.1.0 ]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='***', release='5.15.0-1018-gcp-tcpx', version='#18-Ubuntu SMP Fri Jul 26 14:21:24 UTC 2024', machine='x86_64')
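The environment report above lists jax, jaxlib, and numpy versions but not the NVIDIA wheels that jax[cuda12] pulled in. A small stdlib sketch for listing them from inside the same environment, for comparison with the uv output earlier in the report (the helper name cuda_wheel_versions is hypothetical):

```python
from importlib import metadata


def cuda_wheel_versions(prefixes: tuple[str, ...] = ("jax", "nvidia-")) -> dict[str, str]:
    """Map installed distributions whose names start with a prefix to their versions."""
    found = {}
    for dist in metadata.distributions():
        name = dist.metadata["Name"]
        if not name:
            continue
        # Normalise so "jax_cuda12_plugin" and "jax-cuda12-plugin" compare equal.
        normalized = name.lower().replace("_", "-")
        if normalized.startswith(prefixes):
            found[normalized] = dist.version
    return dict(sorted(found.items()))


if __name__ == "__main__":
    for name, version in cuda_wheel_versions().items():
        print(f"{name:28} {version}")
```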

$ nvidia-smi
Thu May 22 12:57:55 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H100 80GB HBM3          On  | 00000000:04:00.0 Off |                    0 |
| N/A   42C    P0              73W / 700W |      4MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    No running processes found                                                         |
+---------------------------------------------------------------------------------------+
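nvidia-smi reports driver 535.104.05 with CUDA Version 12.2, while the CUDA wheels installed earlier come from the 12.8/12.9 series (for example nvidia-cuda-runtime-cu12==12.9.37). To read the driver's CUDA version from Python, e.g. to log it next to the wheel versions, one option is to call the CUDA driver API function cuDriverGetVersion through ctypes. A sketch, assuming a Linux driver that exposes libcuda.so.1; the helper name driver_cuda_version is hypothetical:

```python
import ctypes
from typing import Optional, Tuple


def driver_cuda_version() -> Optional[Tuple[int, int]]:
    """Return (major, minor) of the CUDA version the driver supports, or None."""
    try:
        libcuda = ctypes.CDLL("libcuda.so.1")
    except OSError:
        return None  # no NVIDIA driver library on the loader path
    get_version = libcuda.cuDriverGetVersion
    get_version.argtypes = [ctypes.POINTER(ctypes.c_int)]
    get_version.restype = ctypes.c_int
    version = ctypes.c_int()
    if get_version(ctypes.byref(version)) != 0:  # 0 == CUDA_SUCCESS
        return None
    # The driver encodes the version as 1000 * major + 10 * minor (12020 -> 12.2).
    return version.value // 1000, (version.value % 1000) // 10


if __name__ == "__main__":
    print("driver supports CUDA:", driver_cuda_version())
```

On the machine above, a driver reporting CUDA Version 12.2 would be expected to yield (12, 2).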

Metadata

Assignees: No one assigned

Labels: bug (Something isn't working)
