Apple Silicon: error: failed to legalize operation 'mhlo.cholesky' #16321

Open
adam-hartshorne opened this issue Jun 8, 2023 · 21 comments

@adam-hartshorne

adam-hartshorne commented Jun 8, 2023

Description

After building jaxlib as per the instructions and installing jax-metal, I tested an existing model that works fine on CPU (and on GPU under Linux) and got the following error:

jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/adam/Developer/Pycharm Projects/gpy_flow_test/sparse_gps.py:66:0: error: failed to legalize operation 'mhlo.cholesky'
/Users/adam/Developer/Pycharm Projects/gpy_flow_test/sparse_gps.py:66:0: note: called from
/Users/adam/Developer/Pycharm Projects/gpy_flow_test/sparse_gps.py:66:0: note: see current operation: %406 = "mhlo.cholesky"(%405) {lower = true} : (tensor<50x50xf32>) -> tensor<50x50xf32>

The full error message is very long, so it is attached here.

cholesky_full_error.txt.zip

I tried the minimal example below, which also calls the Cholesky operator, but it did not reproduce the error. I am more than happy to run a more in-depth test. Any suggestions?

from jax import jit
import jax.numpy as jnp
import jax.random as jnr
import jax.scipy as jsp

key = jnr.PRNGKey(0)
A = jnr.normal(key, (100,100))

def calc_cholesky_decomp(test_matrix):
    psd_test_matrix = test_matrix @ test_matrix.T
    col_decomp = jsp.linalg.cholesky(psd_test_matrix, lower=True)
    return col_decomp

calc_cholesky_decomp(A)

jitted_calc_cholesky_decomp = jit(calc_cholesky_decomp)
jitted_calc_cholesky_decomp(A)
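
A variant that sits closer to the failing op in the error message may be worth trying. The sketch below assumes the operand reported there (`tensor<50x50xf32>`, `lower = true`) and prints `jax.default_backend()` to confirm which backend actually runs the computation. The function name and the added identity term are illustrative choices, and it is not known whether this reproduces the failure on jax-metal.

```python
import jax
import jax.numpy as jnp
import jax.random as jnr
import jax.scipy as jsp

# Confirm which backend JAX selected. If this prints "cpu", the Metal
# plugin is not being exercised by the test at all.
print("default backend:", jax.default_backend())

key = jnr.PRNGKey(0)
# Match the operand in the error message: a 50x50 float32 matrix.
A = jnr.normal(key, (50, 50), dtype=jnp.float32)
# Adding the identity keeps the matrix safely positive definite in float32
# (an illustrative choice, not taken from the failing model).
M = A @ A.T + jnp.eye(50, dtype=jnp.float32)

@jax.jit
def lower_cholesky(psd_matrix):
    # Same call as in the minimal example above, with lower=True.
    return jsp.linalg.cholesky(psd_matrix, lower=True)

L = lower_cholesky(M)
print(L.shape, L.dtype)
```

If the backend line reports the Metal device and the call still succeeds, the difference from the failing model probably lies elsewhere (for example in how the op is reached inside a larger jitted function).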

What jax/jaxlib version are you using?

jaxlib 0.4.10 (metal), jax 0.4.11

Which accelerator(s) are you using?

CPU/GPU

Additional system info

Python v3.10.10, Apple M2

NVIDIA GPU info

No response

adam-hartshorne added the "bug" label on Jun 8, 2023
@hawkinsp
Collaborator

hawkinsp commented Jun 8, 2023

@shuhand0 @kulinseth

hawkinsp added the "enhancement" label and removed the "bug" label on Jun 15, 2023
@benjaminvatterj

Can confirm that this is still broken in version 0.0.5

@c0g

c0g commented Mar 1, 2024

Any update on ETA here? I am trying to use Brax on Metal and it wants the cholesky decomp.

@kulinseth

@shuhand0
Collaborator

shuhand0 commented Mar 6, 2024

Looking into adding the conversion of the op.

@benjaminvatterj

I just wanted to note that it's still not implemented in version 0.0.6, in case anyone noticed the new release.

@vhaasteren

I'm also eagerly awaiting this

@mvanaltvorst

mvanaltvorst commented Apr 4, 2024

Would love to use multivariate normal distributions, which depend on the Cholesky decomposition. Am eagerly awaiting this.
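
For context on that dependency, here is a minimal sketch of Cholesky-based sampling from a multivariate normal. The helper `sample_mvn` and the example mean and covariance are hypothetical names and values chosen for illustration; this is not presented as the library's own implementation.

```python
import jax.numpy as jnp
import jax.random as jnr
import jax.scipy as jsp

def sample_mvn(key, mean, cov, num_samples):
    """Draw samples from N(mean, cov) using a lower Cholesky factor."""
    # cov = chol @ chol.T, so mean + chol @ z has covariance cov
    # whenever z is standard normal.
    chol = jsp.linalg.cholesky(cov, lower=True)
    z = jnr.normal(key, (num_samples, mean.shape[0]))
    return mean + z @ chol.T

mean = jnp.array([1.0, -2.0])
cov = jnp.array([[2.0, 0.6], [0.6, 1.0]])
samples = sample_mvn(jnr.PRNGKey(0), mean, cov, num_samples=1000)
print(samples.shape)
```

The `cholesky` call is the step that needs the unsupported op, so any sampler built this way would hit the same legalization error on a backend that lacks it.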

@driesmarzougui

Still not working in jax-metal v0.0.7

@benjaminvatterj

We're approaching the one year mark on this. Any hope that this would be resolved soon?

@c0g

c0g commented May 15, 2024

Is jax-metal open source? I can’t find it but would consider contributing.

@benjaminvatterj

As far as I know, it's maintained by people at Apple (@kulinseth). I believe they don't share their code.

@vhaasteren

I can report that v0.1.0 still does not address this

@yangfengzzz

I saw that WWDC24 showed JAX supporting MuJoCo, but when I try MJX it still runs into this issue.
[Screenshot 2024-06-15 06 51 07]

@yangfengzzz

After looking at the code, I found that cholesky is defined in jaxlib. Does that mean plugging in the Metal backend through PJRT cannot solve this problem?
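
One way to check what the backend is actually handed is to print the lowered module for a small Cholesky call. This is a sketch using `jax.jit(...).lower(...).as_text()`; the function name and input are illustrative. What the filtered lines show depends on the backend and JAX version: my assumption is that the CPU backend tends to emit a LAPACK custom call, while a backend without a dedicated lowering would see a cholesky op directly.

```python
import jax
import jax.numpy as jnp

def lower_cholesky(psd_matrix):
    return jnp.linalg.cholesky(psd_matrix)

# A small, trivially positive-definite input.
x = 2.0 * jnp.eye(4, dtype=jnp.float32)

# Lower (but do not compile) the function and get the module as text.
ir_text = jax.jit(lower_cholesky).lower(x).as_text()

# Print only the lines that mention the factorization.
for line in ir_text.splitlines():
    if "cholesky" in line or "potrf" in line:
        print(line.strip())
```

Since lowering happens before compilation, this should run even where compiling the op fails, which makes it easy to compare the output on CPU and on Metal.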

@mmattamala

The problem persists with jax-metal v0.1.0, jax v0.4.31 and jaxlib v0.4.31.

@ojwenzel

ojwenzel commented Oct 2, 2024

I am getting the error as well. I would love to use jax-metal, but that is impossible if it does not cover basic linalg ops such as 'mhlo.cholesky'. :/

@adam-hartshorne
Author

adam-hartshorne commented Oct 2, 2024

Unfortunately with Apple having their own ML framework that natively supports M chips (which is very similar to JAX), it seems that getting full metal compatibility with JAX is low priority.

@adam-hartshorne
Author

adam-hartshorne commented Oct 12, 2024

This still doesn't work with the latest macOS and jax-metal v0.1.1.

@HaoruXue

Would be a big thing for the robotics community to have this resolved. Really want to have Mujoco MJX running on Apple Silicon 😥

@Stocko-2073

@shuhand0 @kulinseth Any updates on this? My robot is aching to learn faster.

@iamruiyang

@shuhand0 @kulinseth Any updates on this?
