Description
Problem
Running inference with JAX 0.6.1 after disabling the thunk runtime results in a runtime failure on Arm Neoverse CPUs and a latency regression on x86 CPUs compared to JAX 0.6.0.
The runtime failure message:
could not create a primitive descriptor for a matmul primitive
Trace/breakpoint trap (core dumped)
The failure appears to stem from the OpenXLA -> oneDNN path, as this is the flow matmul follows when the thunk runtime is disabled; a minimal sketch isolating this path is shown below.
This issue was not seen on JAX 0.6.0.
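A minimal sketch that may reproduce the failure without transformers, assuming a plain jnp.dot is lowered through the same oneDNN matmul primitive when the thunk runtime is off; the shapes are arbitrary, and XLA_FLAGS must be set before jax is first imported.

import os

# Disable the thunk runtime so matmul goes through the
# OpenXLA -> oneDNN path described above. This must run
# before jax is imported for the flag to take effect.
os.environ["XLA_FLAGS"] = "--xla_cpu_use_thunk_runtime=false"

import jax.numpy as jnp

# Arbitrary shapes, assumed large enough for XLA to dispatch
# the dot to a oneDNN matmul primitive.
a = jnp.ones((512, 512), dtype=jnp.float32)
b = jnp.ones((512, 512), dtype=jnp.float32)

out = jnp.dot(a, b)
out.block_until_ready()  # force execution; the crash, if any, surfaces here
print(out.shape)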
Reproducer
On an Arm Neoverse V2 machine, run
pip install jax==0.6.1 transformers flax
export XLA_FLAGS=--xla_cpu_use_thunk_runtime=false
python script.py
script.py
from transformers import AutoTokenizer, FlaxAutoModelForSequenceClassification

model_name = "facebook/bart-large"
sequence_length = 8
batch_size = 32
TEXT = "Sample "

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = FlaxAutoModelForSequenceClassification.from_pretrained(model_name)

# Build a batch of identical fixed-length inputs and run one forward pass.
encoded_input = tokenizer([TEXT * sequence_length] * batch_size, return_tensors="np")
output = model(**encoded_input)
Running the same script with a timer on an x86 platform shows a latency regression relative to JAX 0.6.0.
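A timing sketch, assuming wall-clock time around the forward pass is enough to show the regression; the warm-up call keeps tracing and compilation out of the measurement, and block_until_ready accounts for JAX's asynchronous dispatch.

import time

# Warm up once so tracing/compilation is not counted.
output = model(**encoded_input)
output.logits.block_until_ready()

start = time.perf_counter()
output = model(**encoded_input)
output.logits.block_until_ready()  # wait for the computation to finish
print(f"forward pass latency: {time.perf_counter() - start:.3f} s")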
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.6.1
jaxlib: 0.6.1
numpy: 2.2.6
python: 3.10.12 (main, Feb 4 2025, 14:57:36) [GCC 11.4.0]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='ip-172-31-47-163', release='6.8.0-1029-aws', version='#31~22.04.1-Ubuntu SMP Thu Apr 24 20:59:24 UTC 2025', machine='aarch64')