lax.cond crashes on Windows #29049

Open
@lnabergall

Description

On Windows 11 with JAX 0.6.1, the following code

```python
import jax
import jax.numpy as jnp
A = jax.lax.cond(jnp.bool(True), lambda: jnp.float32(0), lambda: jnp.float32(0))
```

crashes Python, throwing the error

2025-05-28 02:29:07.453773: F external/xla/xla/hlo/ir/hlo_instruction.cc:3617] Check failed: PRED == operand(0)->shape().element_type() (1 vs. 4)

The same code on WSL2 runs without issue. Any idea what the problem is?
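For anyone trying to narrow this down: in XLA's `PrimitiveType` enum, `PRED` is 1 and `S32` is 4, so the failed check reads as "expected a boolean predicate operand, found a 32-bit integer". The snippet below is a minimal diagnostic sketch, not a fix. It prints the dtype JAX assigns to the predicate and traces the same `lax.cond` call with `jax.make_jaxpr`, which builds the jaxpr without compiling the conditional. The wrapper `f` is a hypothetical helper, and `jnp.bool_` stands in for `jnp.bool` on the assumption that both name the same scalar type. Whether the output differs between Windows and WSL2 is not something the report establishes.

```python
import jax
import jax.numpy as jnp

# Same predicate as in the report, using the older alias jnp.bool_.
pred = jnp.bool_(True)
print("predicate dtype:", pred.dtype)


def f(p):
    # Hypothetical wrapper around the failing call so it can be traced.
    return jax.lax.cond(p, lambda: jnp.float32(0), lambda: jnp.float32(0))


# make_jaxpr only traces f; the conditional is not compiled or run here.
jaxpr = jax.make_jaxpr(f)(pred)
print(jaxpr)
```

Comparing the printed jaxpr on native Windows and on WSL2 would show whether the two platforms already diverge at trace time or only later, during XLA compilation.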

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.6.1
jaxlib: 0.6.1
numpy:  2.2.6
python: 3.10.8 | packaged by conda-forge | (main, Nov 24 2022, 14:07:00) [MSC v.1916 64 bit (AMD64)]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Windows', node='wizardhat', release='10', version='10.0.26100', machine='AMD64')

$ nvidia-smi
Wed May 28 02:37:56 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 556.12                 Driver Version: 556.12         CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3050 ...  WDDM  |   00000000:01:00.0 Off |                  N/A |
| N/A   50C    P8              3W /   40W |    3155MiB /   4096MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
+-----------------------------------------------------------------------------------------+

Metadata

Labels: bug