Runtime failure/regression while running inference on JAX 0.6.1 #29070

Open
@Rohanjames1997

Description

Problem

Running inference with JAX 0.6.1 after disabling the thunk runtime results in a runtime failure on Arm Neoverse CPUs and a latency regression on x86 CPUs relative to JAX 0.6.0.

The runtime failure message:

could not create a primitive descriptor for a matmul primitive
Trace/breakpoint trap (core dumped)

The failure appears to stem from the OpenXLA -> oneDNN path, which is the route matmul takes when the thunk runtime is disabled.

This issue was not seen on JAX 0.6.0.
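To help isolate it, a bare matmul under the same flag may be enough to hit the oneDNN path; the sketch below is illustrative (the shapes and dtype are assumptions, not taken from the model):

import os
# Disable the thunk runtime before JAX initializes, mirroring the XLA_FLAGS export below.
os.environ["XLA_FLAGS"] = "--xla_cpu_use_thunk_runtime=false"

import jax.numpy as jnp

# Illustrative shapes; adjust to match the model's projection layers if needed.
a = jnp.ones((32, 1024), dtype=jnp.float32)
b = jnp.ones((1024, 1024), dtype=jnp.float32)

# With the thunk runtime disabled, this dot should be lowered to a oneDNN matmul.
out = jnp.dot(a, b)
out.block_until_ready()
print(out.shape)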


Reproducer

On an Arm Neoverse V2 machine, run

pip install jax==0.6.1 transformers flax
export XLA_FLAGS=--xla_cpu_use_thunk_runtime=false
python script.py

script.py
from transformers import AutoTokenizer, FlaxAutoModelForSequenceClassification

model_name = "facebook/bart-large"
sequence_length = 8
batch_size = 32
TEXT = "Sample "

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = FlaxAutoModelForSequenceClassification.from_pretrained(model_name)

# Tokenize a batch of identical short sequences into numpy arrays.
encoded_input = tokenizer([TEXT * sequence_length] * batch_size, return_tensors="np")

# Forward pass; this is where the failure (Arm) / regression (x86) surfaces.
output = model(**encoded_input)

Running the same script with a timer on an x86 platform shows a latency regression relative to JAX 0.6.0; a sketch of such a timer follows.
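
A minimal timing harness around the forward pass (the helper name and iteration count are illustrative, not the exact measurement used):

import time
import jax

def time_forward(model, encoded_input, iters=10):
    # Warm-up run so compilation time is excluded from the measurement.
    jax.block_until_ready(model(**encoded_input).logits)
    start = time.perf_counter()
    for _ in range(iters):
        # block_until_ready forces JAX's async dispatch to finish before the clock stops.
        jax.block_until_ready(model(**encoded_input).logits)
    return (time.perf_counter() - start) / iters

print(f"mean forward latency: {time_forward(model, encoded_input):.3f} s")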


System info (python version, jaxlib version, accelerator, etc.)

jax:    0.6.1
jaxlib: 0.6.1
numpy:  2.2.6
python: 3.10.12 (main, Feb  4 2025, 14:57:36) [GCC 11.4.0]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='ip-172-31-47-163', release='6.8.0-1029-aws', version='#31~22.04.1-Ubuntu SMP Thu Apr 24 20:59:24 UTC 2025', machine='aarch64')
