Skip to content

ENH: jax_autojit #284

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 16, 2025
Merged

ENH: jax_autojit #284

merged 5 commits into from
May 16, 2025

Conversation

crusaderky
Copy link
Contributor

@crusaderky crusaderky commented Apr 25, 2025

Supercharge xpx.testing.lazy_xp_function.

This is a fairly high level change; @rgommers @pearu @jakevdp I could use your feedback.

Matching SciPy PR: scipy/scipy#22909

Automatic static arguments

scipy has seen a proliferation of lazy_xp_function(func, static_argnames=(...)) in the initial section of many test modules. This information is only useful for JAX and is quite verbose.
With scipy/scipy#22686, this informational quirk that is both specific to JAX and to unit tests moves to the implementation modules.

With this PR, lazy_xp_function no longer accepts parameters static_argnames and static_argnums. Instead, all arguments that are not JAX arrays are automatically treated as static. Note that this behaviour is generally desirable in unit testing but a bad idea in production. Consider:

def f(x: Array, y: float, plus: bool) -> Array:
    return x + y if plus else x - y

j1 = jax.jit(f, static_argnames="plus")
j2 = jax_autojit(f)

jax_autojit is a new internal function of array-api-extra.
In the above example, j2 requires a lot less setup to be tested effectively, but on the flip side it means that it will be re-traced for every different value of y, which likely makes it not fit for purpose in production.

To clarify: jax_autojit is applied and removed on the fly within tests with the xp fixture and it is never used outside of unit testing.

Wrapped inputs and output

There are a few cases of scipy functions returning bespoke containers with arrays inside instead of simple tuples, namedtuples, lists, or dicts of arrays.

Two such examples are

In main, these can't be tested with lazy_xp_function, and as of scipy/scipy#22686 the issue also impacts documentation.

This PR lifts this restriction and allows completely arbitrary objects as parameters and as return values of the functions. If these objects internally contain JAX arrays, lazy_xp_function will now automatically extract them, pass them through the JIT, and reassemble everything for the test function to observe the result.

The rationale is that, in real life, users are unlikely to wrap the scipy functions with jax.jit directly; instead they are more likely to consume their outputs in their own functions and then wrap those with jax.jit:

@jax.jit
def my_user_function(x):
    y = scipy.stats.ttest_ind(x)
    return my_user_consume(y)  # returns array or tuple of arrays

Static non-hashable arguments

Non-hashable objects can now be static.
Consider:

def f(x: Array, verbs: Sequence[str]) -> Array:
    ...
    
lazy_xp_function(f, static_argnames="verbs")

def test_f(xp):
    x = ...
    y = f(x, ("foo", "bar", "baz"))  # OK
    y = f(x, ["foo", "bar", "baz"])  # Fails

This PR fixes it; all objects in the input and output of the function need only be hashable or pickleable.

Dask materialization raises inside a container

This also fixes a Dask-specific bug where the graph materialization raises:

@pytest.mark.skip_xp_backends("jax.numpy", reason="raise inside jax.pure_callback")
def test1(xp):
    with pytest.raises(Exception):
        f(x)

The above works when f returns plain Dask array objects or tuples or lists thereof, even if the graph would otherwise be discarded and never computed, thanks to the lazy_xp_function machinery, but used to fail when the return value is an opaque container with Dask arrays inside. This PR fixes it.

Comment on lines +509 to +512
"""
Register upon first use instead of at import time, to avoid
globally importing JAX.
"""
Copy link
Contributor Author

@crusaderky crusaderky Apr 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aside note on design: Dask avoids this exact problem by not requiring any decorator and instead duck-type checking for uniquely named dunder methods called __dask_<...>__

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting design tradeoffs there - to me the "eager by default, opt-in for graph mode" is nicer and has won in array land though (dataframes are a different story). I guess this code pattern is one of the prices to pay.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: everything in this module is a private helper.

Copy link

@pearu pearu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a minor nit and a note, otherwise it looks good to me as it reduces the maintenance burden of adding/updating static_argnames/argnums arguments.

Btw, I enjoyed reading pickle_flatten and pickle_flatten implementations as it felt like less is more.

Thanks, @crusaderky!

@jakevdp
Copy link

jakevdp commented Apr 27, 2025

For what it's worth, one of the main reasons we haven't implemented something like this in JAX is because it tends to negatively impact dispatch times. The output of jax.jit is a C++ level callable, that directly dispatches to the compiled kernel after the initial call. This is important because if you put a layer of Python logic in the dispatch path to inspect all the arguments and determine whether to treat them as static or not (and possibly re-compile based on this information) then it severely impacts dispatch times.

With that background, I'd like a bit of clarification here: is this intended for use only in testing paths, or do you imagine this will be used within the dispatch path for libraries like scipy that implement the Array API?

@lucascolley
Copy link
Member

I think this section of the top post answers your question, Jake:

To clarify: jax_autojit is applied and removed on the fly within tests with the xp fixture and it is never used outside of unit testing.

crusaderky and others added 2 commits April 28, 2025 09:50
@crusaderky
Copy link
Contributor Author

scipy/scipy#22909 does not cause any issues to crop up.
This PR is ready for final review.

@lucascolley
Copy link
Member

@crusaderky will this close gh-270?

@crusaderky
Copy link
Contributor Author

@crusaderky will this close gh-270?

Yes, it does!

@lucascolley lucascolley linked an issue May 11, 2025 that may be closed by this pull request
@lucascolley lucascolley modified the milestones: 1.0.0, 0.8.0 May 11, 2025
@lucascolley lucascolley requested review from rgommers and pearu May 12, 2025 15:12
@rgommers
Copy link
Member

@lucascolley are you happy with me merging this once I'm done reviewing, or do you want to review it as well? I'd like to get it into scipy 1.16.x if possible. This PR itself isn't quite user-facing, but the docs improvements that build on this PR are.

@lucascolley
Copy link
Member

Yes, I took a look through from a non-expert perspective and nothing seems untoward!

Copy link
Member

@rgommers rgommers left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very nice, it makes testing of JAX more generic and much easier to understand for maintainers and new contributors to SciPy et al. who don't have much (or any) JAX experience. The test simplifications in this PR show big the difference is.

👍🏼 for the PR description as well, that was very useful to me - and I expect others are going to refer to it in the future.

Btw, I enjoyed reading pickle_flatten and pickle_flatten implementations as it felt like less is more.

💯 agreed

I have some small non-blocking doc suggestions only. Those can wait or be left alone; I'll go ahead and merge this PR now so we call pull it in to unblock the SciPy release.

The machinery added here is quite complex, so it's very well possible I missed something when reviewing. I tested fairly extensively with SciPy and all looks good though, so nothing major should be off here.

Thanks @crusaderky & reviewers!

Comment on lines +509 to +512
"""
Register upon first use instead of at import time, to avoid
globally importing JAX.
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting design tradeoffs there - to me the "eager by default, opt-in for graph mode" is nicer and has won in array land though (dataframes are a different story). I guess this code pattern is one of the prices to pay.

- Automatically descend into non-array return values and find ``jax.Array`` objects
inside them, then rebuild them downstream of exiting the JIT, swapping the JAX
tracer objects with concrete arrays.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These seem like a good set of choices for testing purposes to me. For the unsuspecting reader who may be looking at this code without much context, it'd be useful to add something like this:

Note: these are useful choices *for testing purposes only*, which is how this function is
intended to be used.

static_argnames : str | Iterable[str], optional
Passed to jax.jit. Named arguments to treat as static (compile-time constant).
Default: infer from `static_argnums` using `inspect.signature(func)`.
Set to True to replace `func` with a smart variant of ``jax.jit(func)`` after
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor: this description starting with "Set to True" hints at the default being False, which is not the case. I'd switch the True/False cases around and start the True case with "When set to True (default)"

Note that this isn't a new issue, it just gets magnified because the "Default: True is a lot further down".

@rgommers rgommers merged commit 28a364d into data-apis:main May 16, 2025
9 checks passed
@crusaderky crusaderky deleted the autojit branch May 16, 2025 21:06
@crusaderky crusaderky mentioned this pull request May 16, 2025
NeilGirdhar pushed a commit to NeilGirdhar/array-api-extra that referenced this pull request Jun 2, 2025
NeilGirdhar pushed a commit to NeilGirdhar/array-api-extra that referenced this pull request Jun 2, 2025
NeilGirdhar pushed a commit to NeilGirdhar/array-api-extra that referenced this pull request Jun 2, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

ENH: lazy_xp_function support for wrapped return values
6 participants