
NotImplementedError: MLIR translation rule for primitive 'debug_callback' not found for platform METAL #28786

Open
@amavrits

Description

I am running JAX on the Metal backend and encounter the following error (excerpted from a much longer traceback):

File ".../.venv/lib/python3.11/site-packages/jax/_src/interpreters/mlir.py", line 2041, in lower_per_platform
    raise NotImplementedError(
NotImplementedError: MLIR translation rule for primitive 'debug_callback' not found for platform METAL

jax-tqdm appears to be the culprit: with it disabled, the error no longer occurs.
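A minimal reproduction without jax-tqdm might look like the sketch below. It assumes, without confirmation from this report, that jax-tqdm updates its progress bar through `jax.debug.callback`, which binds the `debug_callback` primitive named in the error. On a CPU backend the jitted loop compiles and runs; compiled for Metal, the same lowering would be expected to raise the `NotImplementedError` above. The `jax.default_device` block is a possible workaround that places the computation on the CPU device. It has not been tested on Metal, and the names `report_step`, `body` and `run` are hypothetical.

```python
import jax
import jax.numpy as jnp

# Host-side record of the loop indices the callback received (illustrative only).
seen_steps = []


def report_step(i):
    # Runs on the host, outside the compiled program (like a progress-bar update).
    seen_steps.append(int(i))


def body(i, total):
    # jax.debug.callback binds the 'debug_callback' primitive, which needs an
    # MLIR lowering rule for whichever platform the jitted function targets.
    jax.debug.callback(report_step, i)
    return total + i


@jax.jit
def run():
    return jax.lax.fori_loop(0, 5, body, jnp.int32(0))


# Possible workaround (untested on Metal): compile and run on the CPU device,
# whose backend does provide a lowering for debug_callback.
with jax.default_device(jax.devices("cpu")[0]):
    result = run()

jax.effects_barrier()  # wait until all pending host callbacks have completed
print(int(result), sorted(seen_steps))
```

If the CPU fallback is acceptable, it keeps the progress reporting at the cost of not using the GPU for that computation. The alternative the reporter used is to disable the progress bar entirely.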

System info (python version, jaxlib version, accelerator, etc.)

Python 3.11.9
jax 0.5.0
macOS Sequoia 15.5
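The heading above asks for the jaxlib version and accelerator, which are not listed. A sketch for collecting them uses `jax.print_environment_info` and `jax.devices()`. It assumes that the Metal plugin and the progress-bar library are installed under the distribution names `jax-metal` and `jax-tqdm`; the helper `installed_version` is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version

import jax

# jax, jaxlib, numpy and python versions plus basic device information.
info = jax.print_environment_info(return_string=True)
print(info)


def installed_version(dist_name):
    # Hypothetical helper: report a package version, or note its absence.
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return "not installed"


# Assumed distribution names for the Metal plugin and the progress-bar library.
metal_version = installed_version("jax-metal")
tqdm_version = installed_version("jax-tqdm")
print("jax-metal:", metal_version)
print("jax-tqdm:", tqdm_version)

# Platforms JAX actually sees, which shows whether the Metal plugin is active.
print([d.platform for d in jax.devices()])
```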
