Apple Silicon: error: failed to legalize operation 'mhlo.cholesky' #16321
Comments
Can confirm that this is still broken in version 0.0.5.
Any update on an ETA here? I am trying to use Brax on Metal and it needs the Cholesky decomposition.
Looking into adding the conversion of the op.
I just wanted to note that it's still not implemented in version 0.0.6, in case anyone noticed the new release.
I'm also eagerly awaiting this.
Would love to use multivariate normal distributions, which depend on the Cholesky decomposition. Am eagerly awaiting this.
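For readers wondering why multivariate normal sampling runs into this op: a common way to draw samples is to factor the covariance as cov = L Lᵀ and transform standard normal draws with L. The sketch below illustrates that dependency only; `sample_mvn` and the example `mean`/`cov` values are made up for illustration and are not taken from this thread or from JAX internals.

```python
import jax
import jax.numpy as jnp


def sample_mvn(key, mean, cov, n):
    """Draw n samples from N(mean, cov) via the Cholesky factor of cov."""
    # cov = L @ L.T with L lower triangular; this is the call that
    # lowers to a Cholesky op in the compiled program.
    l = jnp.linalg.cholesky(cov)
    # Standard normal draws, one row per sample.
    z = jax.random.normal(key, (n, mean.shape[0]), dtype=mean.dtype)
    # Each row z_i maps to mean + L @ z_i, written in batched form.
    return mean + z @ l.T


# Hypothetical example values.
mean = jnp.array([1.0, -2.0])
cov = jnp.array([[2.0, 0.6], [0.6, 1.0]])

samples = sample_mvn(jax.random.PRNGKey(0), mean, cov, 20000)
print(samples.mean(axis=0))
print(jnp.cov(samples.T))
```

The printed sample mean and covariance should land close to `mean` and `cov`; on a backend that cannot legalize the Cholesky op, the failure would be expected at the `jnp.linalg.cholesky` call.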
Still not working in jax-metal v0.0.7.
We're approaching the one-year mark on this. Any hope that this will be resolved soon?
Is jax-metal open source? I can’t find it but would consider contributing.
As far as I know it's maintained by people at Apple (@kulinseth). I believe they don't share their code.
I can report that v0.1.0 still does not address this.
After looking at the code, I found that cholesky is defined in jaxlib. Does that mean that plugging in the Metal backend through PJRT cannot solve this problem?
Problem persists with
I am getting the error as well. I would love to use jax-metal, but that is impossible if it does not cover basic linalg such as 'mhlo.cholesky'. :/
Unfortunately, with Apple having their own ML framework that natively supports M-series chips (and which is very similar to JAX), it seems that getting full Metal compatibility with JAX is low priority.
This still doesn't work with the latest macOS and jax-metal v0.1.1.
Would be a big thing for the robotics community to have this resolved. Really want to have MuJoCo MJX running on Apple Silicon 😥
@shuhand0 @kulinseth Any updates on this? My robot is aching to learn faster.
@shuhand0 @kulinseth Any updates on this?
Description
After building jaxlib as per the instructions and installing jax-metal, upon testing with an existing model which works fine using CPU (and GPU on linux), I get the following error.
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/adam/Developer/Pycharm Projects/gpy_flow_test/sparse_gps.py:66:0: error: failed to legalize operation 'mhlo.cholesky'
/Users/adam/Developer/Pycharm Projects/gpy_flow_test/sparse_gps.py:66:0: note: called from
/Users/adam/Developer/Pycharm Projects/gpy_flow_test/sparse_gps.py:66:0: note: see current operation: %406 = "mhlo.cholesky"(%405) {lower = true} : (tensor<50x50xf32>) -> tensor<50x50xf32>
The full error message is very long, and is attached here.
cholesky_full_error.txt.zip
I did try a minimal example shown below which also calls the cholesky operator, but I couldn't reproduce the same error. I am more than happy to try another more in-depth test code. Any suggestions?
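The original minimal example is not included in this text. As a stand-in, here is a rough sketch of a jitted Cholesky call on a 50x50 float32 matrix, matching the tensor shape in the error above. It is not the snippet from the report; `spd_matrix` and `chol` are hypothetical helper names, and whether it triggers the same error on a given jax-metal version is untested here.

```python
import jax
import jax.numpy as jnp


def spd_matrix(n, seed=0):
    """Build a symmetric positive-definite float32 matrix of size n x n."""
    key = jax.random.PRNGKey(seed)
    a = jax.random.normal(key, (n, n), dtype=jnp.float32)
    # a @ a.T is positive semi-definite; the diagonal shift makes it
    # safely positive definite so the factorization is well defined.
    return a @ a.T + n * jnp.eye(n, dtype=jnp.float32)


@jax.jit
def chol(k):
    # Under jit this is compiled for the default backend, which is
    # where an unsupported op would be reported.
    return jnp.linalg.cholesky(k)


k = spd_matrix(50)
l = chol(k)

# Relative reconstruction error of L @ L.T against the input.
rel_err = float(jnp.max(jnp.abs(l @ l.T - k)) / jnp.max(jnp.abs(k)))

print(jax.devices())  # shows which backend actually ran the computation
print(rel_err)
```

If this runs cleanly, checking the `jax.devices()` output is a quick way to confirm that the computation really went to the Metal device rather than the CPU.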
What jax/jaxlib version are you using?
jaxlib 0.4.10 (metal), jax 0.4.11
Which accelerator(s) are you using?
CPU/GPU
Additional system info
Python v3.10.10, Apple M2
NVIDIA GPU info
No response