Skip to content

ENH: lazy_xp_function namespaces support #158

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 73 additions & 28 deletions src/array_api_extra/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,8 @@ def lazy_xp_function( # type: ignore[no-any-explicit]
"""
Tag a function to be tested on lazy backends.

Tag a function, which must be imported in the test module globals, so that when any
tests defined in the same module are executed with ``xp=jax.numpy`` the function is
replaced with a jitted version of itself, and when it is executed with
Tag a function so that when any tests are executed with ``xp=jax.numpy`` the
function is replaced with a jitted version of itself, and when it is executed with
``xp=dask.array`` the function will raise if it attempts to materialize the graph.
This will be later expanded to provide test coverage for other lazy backends.

Expand Down Expand Up @@ -120,19 +119,59 @@ def test_myfunc(xp):

Notes
-----
A test function can circumvent this monkey-patching system by calling `func` as an
attribute of the original module. You need to sanitize your code to make sure this
does not happen.
In order for this tag to be effective, the test function must be imported into the
test module globals without its namespace; alternatively its namespace must be
declared in a ``lazy_xp_modules`` list in the test module globals.

Example::
Example 1::

import mymodule from mymodule import myfunc
from mymodule import myfunc

lazy_xp_function(myfunc)

def test_myfunc(xp):
a = xp.asarray([1, 2]) b = myfunc(a) # This is jitted when xp=jax.numpy c =
mymodule.myfunc(a) # This is not
x = myfunc(xp.asarray([1, 2]))

Example 2::

import mymodule

lazy_xp_modules = [mymodule]
lazy_xp_function(mymodule.myfunc)

def test_myfunc(xp):
x = mymodule.myfunc(xp.asarray([1, 2]))

A test function can circumvent this monkey-patching system by using a namespace
outside of the two above patterns. You need to sanitize your code to make sure this
only happens intentionally.

Example 1::

import mymodule
from mymodule import myfunc

lazy_xp_function(myfunc)

def test_myfunc(xp):
a = xp.asarray([1, 2])
b = myfunc(a) # This is wrapped when xp=jax.numpy or xp=dask.array
c = mymodule.myfunc(a) # This is not

Example 2::

import mymodule

class naked:
myfunc = mymodule.myfunc

lazy_xp_modules = [mymodule]
lazy_xp_function(mymodule.myfunc)

def test_myfunc(xp):
a = xp.asarray([1, 2])
b = mymodule.myfunc(a) # This is wrapped when xp=jax.numpy or xp=dask.array
c = naked.myfunc(a) # This is not
"""
tags = {
"allow_dask_compute": allow_dask_compute,
Expand All @@ -153,11 +192,13 @@ def patch_lazy_xp_functions(
Test lazy execution of functions tagged with :func:`lazy_xp_function`.

If ``xp==jax.numpy``, search for all functions which have been tagged with
:func:`lazy_xp_function` in the globals of the module that defines the current test
:func:`lazy_xp_function` in the globals of the module that defines the current test,
as well as in the ``lazy_xp_modules`` list in the globals of the same module,
and wrap them with :func:`jax.jit`. Unwrap them at the end of the test.

If ``xp==dask.array``, wrap the functions with a decorator that disables
``compute()`` and ``persist()``.
``compute()`` and ``persist()`` and ensures that exceptions and warnings are raised
eagerly.

This function should be typically called by your library's `xp` fixture that runs
tests on multiple backends::
Expand All @@ -183,29 +224,33 @@ def xp(request, monkeypatch):
lazy_xp_function : Tag a function to be tested on lazy backends.
pytest.FixtureRequest : `request` test function parameter.
"""
globals_ = cast("dict[str, Any]", request.module.__dict__) # type: ignore[no-any-explicit]

def iter_tagged() -> Iterator[tuple[str, Callable[..., Any], dict[str, Any]]]: # type: ignore[no-any-explicit]
for name, func in globals_.items():
tags: dict[str, Any] | None = None # type: ignore[no-any-explicit]
with contextlib.suppress(AttributeError):
tags = func._lazy_xp_function # pylint: disable=protected-access
if tags is None:
with contextlib.suppress(KeyError, TypeError):
tags = _ufuncs_tags[func]
if tags is not None:
yield name, func, tags
mod = cast(ModuleType, request.module)
mods = [mod, *cast(list[ModuleType], getattr(mod, "lazy_xp_modules", []))]

def iter_tagged() -> ( # type: ignore[no-any-explicit]
Iterator[tuple[ModuleType, str, Callable[..., Any], dict[str, Any]]]
):
for mod in mods:
for name, func in mod.__dict__.items():
tags: dict[str, Any] | None = None # type: ignore[no-any-explicit]
with contextlib.suppress(AttributeError):
tags = func._lazy_xp_function # pylint: disable=protected-access
if tags is None:
with contextlib.suppress(KeyError, TypeError):
tags = _ufuncs_tags[func]
if tags is not None:
yield mod, name, func, tags

if is_dask_namespace(xp):
for name, func, tags in iter_tagged():
for mod, name, func, tags in iter_tagged():
n = tags["allow_dask_compute"]
wrapped = _dask_wrap(func, n)
monkeypatch.setitem(globals_, name, wrapped)
monkeypatch.setattr(mod, name, wrapped)

elif is_jax_namespace(xp):
import jax

for name, func, tags in iter_tagged():
for mod, name, func, tags in iter_tagged():
if tags["jax_jit"]:
# suppress unused-ignore to run mypy in -e lint as well as -e dev
wrapped = cast( # type: ignore[no-any-explicit]
Expand All @@ -216,7 +261,7 @@ def iter_tagged() -> Iterator[tuple[str, Callable[..., Any], dict[str, Any]]]:
static_argnames=tags["static_argnames"],
),
)
monkeypatch.setitem(globals_, name, wrapped)
monkeypatch.setattr(mod, name, wrapped)


class CountingDaskScheduler(SchedulerGetCallable):
Expand Down
37 changes: 37 additions & 0 deletions tests/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def non_materializable(x: Array) -> Array:
and it will trigger an expensive computation in dask.
"""
xp = array_namespace(x)
# Crashes inside jax.jit
# On dask, this triggers two computations of the whole graph
if xp.any(x < 0.0) or xp.any(x > 10.0):
msg = "Values must be in the [0, 10] range"
Expand Down Expand Up @@ -261,3 +262,39 @@ def test_lazy_xp_function_eagerly_raises(da: ModuleType):
x = da.arange(3)
with pytest.raises(ValueError, match="Hello world"):
dask_raises(x)


class Wrapped:
def f(x: Array) -> Array: # noqa: N805 # pyright: ignore[reportSelfClsParameterName]
xp = array_namespace(x)
# Crash in jax.jit and trigger compute() on dask
if not xp.all(x):
msg = "Values must be non-zero"
raise ValueError(msg)
return x


class Naked:
f = Wrapped.f # pyright: ignore[reportUnannotatedClassAttribute]


lazy_xp_function(Wrapped.f)
lazy_xp_modules = [Wrapped]


def test_lazy_xp_modules(xp: ModuleType, library: Backend):
x = xp.asarray([1.0, 2.0])
y = Naked.f(x)
xp_assert_equal(y, x)

if library is Backend.JAX:
with pytest.raises(
TypeError, match="Attempted boolean conversion of traced array"
):
Wrapped.f(x)
elif library is Backend.DASK:
with pytest.raises(AssertionError, match=r"dask\.compute"):
Wrapped.f(x)
else:
y = Wrapped.f(x)
xp_assert_equal(y, x)