Description
I am running JAX on the Metal backend and hit the following error (the full traceback is much longer):
File ".../.venv/lib/python3.11/site-packages/jax/_src/interpreters/mlir.py", line 2041, in lower_per_platform
raise NotImplementedError(
NotImplementedError: MLIR translation rule for primitive 'debug_callback' not found for platform METAL
jax-tqdm appears to be the culprit: with it disabled, the error goes away.
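A minimal sketch that may help reproduce the problem without jax-tqdm. It assumes that jax-tqdm updates its progress bar from inside compiled code through `jax.debug.callback`, which lowers to the `debug_callback` primitive named in the error. The names `report` and `run` and the `progress` flag are hypothetical. With `progress=True` on a Metal device, this would be expected to fail with the same lowering error. The guard at the bottom is one possible workaround: it skips the callback when the default backend looks like Metal. That check assumes the plugin reports its backend as "METAL", as in the error message.

```python
from functools import partial

import jax
import jax.numpy as jnp


def report(step):
    # Host-side Python function invoked from inside compiled code.
    print(f"step {int(step)}")


@partial(jax.jit, static_argnames="progress")
def run(x, progress=True):
    def body(carry, i):
        if progress:  # plain Python check, resolved at trace time
            # jax.debug.callback lowers to the 'debug_callback' primitive,
            # which per the traceback has no MLIR rule for METAL.
            jax.debug.callback(report, i)
        return carry + 1.0, None

    carry, _ = jax.lax.scan(body, x, jnp.arange(3))
    return carry


# Assumption: the Metal plugin reports its backend as "METAL"; compare
# case-insensitively and skip the host callback there.
on_metal = jax.default_backend().lower() == "metal"
result = run(jnp.float32(0.0), progress=not on_metal)
print(result)
```

Making `progress` a static argument means the callback-free version is traced and compiled separately. The Metal path then never contains a `debug_callback` op at all.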
System info (python version, jaxlib version, accelerator, etc.)
Python 3.11.9
jax 0.5.0
macOS Sequoia 15.5
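The template asks for the jaxlib version, which is not listed above. A short snippet like the following may help collect it. `jax.print_environment_info` reports the installed jax, jaxlib, NumPy and Python versions, among other details, and `jax.devices()` shows which platform JAX actually picked up.

```python
import jax

# Returns the same report that print_environment_info() would print.
info = jax.print_environment_info(return_string=True)
print(info)

# Lists the devices JAX is using (e.g. whether the Metal plugin is active).
print(jax.devices())
```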