Description
On Windows 11 with JAX 0.6.1, the following code
import jax
import jax.numpy as jnp
A = jax.lax.cond(jnp.bool(True), lambda: jnp.float32(0), lambda: jnp.float32(0))
crashes the Python process with the following fatal error:
2025-05-28 02:29:07.453773: F external/xla/xla/hlo/ir/hlo_instruction.cc:3617] Check failed: PRED == operand(0)->shape().element_type() (1 vs. 4)
The same code on WSL2 runs without issue. Any idea what the problem is?
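A minimal sketch for narrowing this down, untested on Windows. It prints the StableHLO that JAX emits for the `lax.cond` without compiling it, then tries `lax.select` as a possible workaround. If the numbers in `(1 vs. 4)` are XLA `PrimitiveType` enum values, 1 is `PRED` and 4 is `S32`, which would suggest the conditional's first operand reached XLA as a 32-bit integer rather than a boolean. Whether `lax.select` avoids the crash on Windows is an assumption. It evaluates both operands, which is harmless here because both branches are constants.

```python
import jax
import jax.numpy as jnp
from jax import lax


def f(pred):
    # Same constant branches as the reproducer; pred is a traced argument,
    # so the cond is staged out rather than resolved in Python.
    return lax.cond(pred, lambda: jnp.float32(0), lambda: jnp.float32(0))


# .lower() only traces and emits StableHLO text; it does not compile, so it
# should stop short of the XLA HLO construction where the check fails.
hlo_text = jax.jit(f).lower(jnp.bool_(True)).as_text()
print(hlo_text)  # inspect the element type of the conditional's operand


def f_select(pred):
    # Possible workaround (assumption, untested on Windows): lax.select takes
    # a boolean predicate directly and evaluates both operands.
    return lax.select(pred, jnp.float32(0), jnp.float32(0))


out = jax.jit(f_select)(jnp.bool_(True))
print(out)
```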
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.6.1
jaxlib: 0.6.1
numpy: 2.2.6
python: 3.10.8 | packaged by conda-forge | (main, Nov 24 2022, 14:07:00) [MSC v.1916 64 bit (AMD64)]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Windows', node='wizardhat', release='10', version='10.0.26100', machine='AMD64')
$ nvidia-smi
Wed May 28 02:37:56 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 556.12                 Driver Version: 556.12         CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3050 ...  WDDM  |   00000000:01:00.0 Off |                  N/A |
| N/A   50C    P8              3W /   40W |    3155MiB /   4096MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
+-----------------------------------------------------------------------------------------+