
ENH: lazy_xp_function support for wrapped return values #270

@crusaderky

Description


scipy.stats.ttest_ind fails when wrapped with lazy_xp_function.
The reason is that the function returns a custom subclass of NamedTuple.
@jax.jit automatically repacks returned lists, tuples, and NamedTuples of arrays, but fails with custom subclasses such as TtestResult, whose constructor requires more arguments than the tuple fields it is rebuilt from.
However, this is an artefact specific to lazy_xp_function: in real life, end users will unpack and consume the return value of ttest_ind within the scope of the jitted function.
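As a rough illustration, the failure can be imitated in plain Python. This assumes, as the traceback suggests, that jax.jit rebuilds a returned NamedTuple by calling its class with the tuple fields only. `TtestLikeResult` and `rebuild_like_pytree` are hypothetical names: the former only mimics the relevant shape of SciPy's `TtestResult`, not its actual implementation, and the latter is not JAX code.

```python
from typing import Any, NamedTuple


class _Base(NamedTuple):
    statistic: float
    pvalue: float


class TtestLikeResult(_Base):
    """Hypothetical stand-in for scipy's TtestResult: a NamedTuple subclass
    whose constructor needs more arguments than its tuple fields."""

    def __new__(cls, statistic, pvalue, df, alternative):
        return super().__new__(cls, statistic, pvalue)

    def __init__(self, statistic, pvalue, df, alternative):
        # Extra attributes live outside the tuple fields.
        self.df = df
        self.alternative = alternative


def rebuild_like_pytree(node: tuple) -> Any:
    """Imitate a pytree round trip of a NamedTuple node (an assumption
    about jax.jit's behaviour, not its actual code)."""
    children = list(node)  # flatten: only the tuple fields are kept
    return type(node)(*children)  # unflatten: call the class positionally


plain = _Base(1.0, 2.0)
assert rebuild_like_pytree(plain) == plain  # a plain NamedTuple round-trips

res = TtestLikeResult(-1.73, 0.33, 1.0, "two-sided")
try:
    rebuild_like_pytree(res)
except TypeError as exc:
    print(exc)  # complains that 'df' and 'alternative' are missing
```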

In other words, this fails:

from array_api_extra.testing import lazy_xp_function
from scipy.stats import ttest_ind

lazy_xp_function(ttest_ind)

def test1(xp):
    x = xp.asarray([.1, .2])
    y = xp.asarray([.3])
    res = ttest_ind(x, y)
    # with xp=numpy: res = TtestResult(statistic=np.float64(-1.7320508075688765), pvalue=np.float64(0.3333333333333334), df=np.float64(1.0))

FAILED test1.py::test1[jax.numpy] - TypeError: TtestResult.__init__() missing 4 required positional arguments: 'df', 'alternative', 'standard_error', and 'estimate'

This is because the test is equivalent to:

>>> import jax
>>> import jax.numpy as jnp
>>> from scipy.stats import ttest_ind
>>> jitted = jax.jit(ttest_ind)
>>> jitted(jnp.asarray([.1, .2]), jnp.asarray([.3]))
TypeError: TtestResult.__init__() missing 4 required positional arguments: 'df', 'alternative', 'standard_error', and 'estimate'

However, in real life users will not write the above; they will instead write something like:

>>> import jax
>>> import jax.numpy as jnp
>>> from scipy.stats import ttest_ind
>>> @jax.jit
... def f(x, y):
...     res = ttest_ind(x, y)
...     # Stand-in for some post-processing
...     return res.statistic, res.pvalue, res.df
>>> f(jnp.asarray([.1, .2]), jnp.asarray([.3]))
(Array(-1.7320508, dtype=float32),
 Array(0.3333333, dtype=float32),
 Array(1., dtype=float32))

Proposed design

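Use pickle hooks to automatically unpack complex return values into something jax.jit can return (e.g. a flat list of arrays) inside the jitted function, and repack them into the original object outside of it.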
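A minimal sketch of what such hooks could look like, assuming they are pickle's `persistent_id`/`persistent_load` plus the return type's own `__reduce__`. The wrapper would pickle the return value while pulling every array out as a separate leaf, return only the flat list of leaves from the jitted function, and unpickle outside the jit to rebuild the original object. `flatten`, `unflatten` and `TtestLikeResult` are hypothetical names, and plain floats stand in for arrays so the sketch runs without JAX.

```python
import io
import pickle
from typing import Any, NamedTuple, Optional


class _Base(NamedTuple):
    statistic: float
    pvalue: float


class TtestLikeResult(_Base):
    """Hypothetical stand-in for scipy's TtestResult."""

    def __new__(cls, statistic, pvalue, df, alternative):
        return super().__new__(cls, statistic, pvalue)

    def __init__(self, statistic, pvalue, df, alternative):
        self.df = df
        self.alternative = alternative

    def __reduce__(self):
        # Pickle hook: rebuild by calling the class with *all* arguments.
        return type(self), (*self, self.df, self.alternative)


def flatten(obj: Any, cls: type) -> tuple[list[Any], bytes]:
    """Pickle ``obj``, pulling out every ``cls`` instance as a separate leaf."""
    leaves: list[Any] = []

    class _Pickler(pickle.Pickler):
        def persistent_id(self, o: Any) -> Optional[int]:
            if isinstance(o, cls):
                leaves.append(o)
                return len(leaves) - 1  # store only the leaf index
            return None  # pickle everything else as usual

    buf = io.BytesIO()
    _Pickler(buf, protocol=pickle.HIGHEST_PROTOCOL).dump(obj)
    return leaves, buf.getvalue()


def unflatten(leaves: list[Any], skeleton: bytes) -> Any:
    """Unpickle ``skeleton``, replacing each leaf index with its leaf."""

    class _Unpickler(pickle.Unpickler):
        def persistent_load(self, pid: int) -> Any:
            return leaves[pid]

    return _Unpickler(io.BytesIO(skeleton)).load()


res = TtestLikeResult(-1.7320508075688765, 0.3333333333333334, 1.0, "two-sided")
leaves, skeleton = flatten(res, float)  # floats stand in for arrays
# Inside the jit, only ``leaves`` (a plain list) would be returned.
rebuilt = unflatten(leaves, skeleton)
```

Because `persistent_id` intercepts the arrays before pickle tries to serialise them, the arrays themselves never need to be picklable; only the container does (here via `__reduce__`). How the `skeleton` bytes would be carried out of the traced function is not covered by this sketch.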
