Skip to content

jax.numpy: make standard input utilities respect __jax_array__ #28177

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 23, 2025

Conversation

jakevdp
Copy link
Collaborator

@jakevdp jakevdp commented Apr 22, 2025

This will be important because currently, __jax_array__ is handled by many NumPy APIs as a side-effect of it being handled in the JIT abstractification pass, but in the future we will likely remove this handling.

Before:

$ JAX_DISABLE_JIT=true pytest -n auto tests/array_extensibility_test.py
...
===================== 168 failed, 134 passed in 10.04s =====================

After:

$ JAX_DISABLE_JIT=true pytest -n auto tests/array_extensibility_test.py
...
====================== 87 failed, 215 passed in 9.40s ======================

@jakevdp jakevdp requested a review from cgarciae April 22, 2025 17:42
@jakevdp jakevdp self-assigned this Apr 22, 2025
@jakevdp jakevdp added the pull ready Ready for copybara import and testing label Apr 22, 2025
@copybara-service copybara-service bot merged commit aeb86f1 into jax-ml:main Apr 23, 2025
23 checks passed
@jakevdp jakevdp deleted the ufunc-jax-array branch April 23, 2025 21:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
kokoro:force-run pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants