Description
`scipy.stats.ttest_ind` fails when wrapped with `lazy_xp_function`.
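The reason is that the function returns a custom subclass of NamedTuple. `@jax.jit` automatically repacks returned lists, tuples, and NamedTuples of arrays, but it rebuilds a NamedTuple by calling its type with the tuple fields only (here `statistic` and `pvalue`). This fails for `TtestResult`, whose constructor also requires `df`, `alternative`, `standard_error`, and `estimate`.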
This, however, is an artefact specific to `lazy_xp_function`: in real life, end users will unpack and consume the return value of `ttest_ind` within the scope of the jitted function.
In other words, this fails:
```python
import array_api_extra as xpx
from scipy.stats import ttest_ind

# Run ttest_ind through jax.jit when the xp fixture is jax.numpy
xpx.lazy_xp_function(ttest_ind)


def test1(xp):
    x = xp.asarray([.1, .2])
    y = xp.asarray([.3])
    res = ttest_ind(x, y)
    # res = TtestResult(statistic=np.float64(-1.7320508075688765), pvalue=np.float64(0.3333333333333334), df=np.float64(1.0))
```

```
FAILED test1.py::test1[jax.numpy] - TypeError: TtestResult.__init__() missing 4 required positional arguments: 'df', 'alternative', 'standard_error', and 'estimate'
```
as it is equivalent to:

```python
>>> import jax
>>> import jax.numpy as jnp
>>> from scipy.stats import ttest_ind
>>> jitted = jax.jit(ttest_ind)
>>> jitted(jnp.asarray([.1, .2]), jnp.asarray([.3]))
TypeError: TtestResult.__init__() missing 4 required positional arguments: 'df', 'alternative', 'standard_error', and 'estimate'
```
However, in real life users will not write the above; they will instead write something like:
```python
>>> from scipy.stats import ttest_ind
>>> @jax.jit
... def f(x, y):
...     res = ttest_ind(x, y)
...     # Stand-in for some post-processing
...     return res.statistic, res.pvalue, res.df
...
>>> f(jnp.asarray([.1, .2]), jnp.asarray([.3]))
(Array(-1.7320508, dtype=float32),
 Array(0.3333333, dtype=float32),
 Array(1., dtype=float32))
```
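The failure can be reproduced without JAX or SciPy. This minimal sketch assumes, based on the error message above, that `jax.jit` rebuilds a returned NamedTuple by calling its type with the tuple fields as positional arguments. `Result` and `rebuild_like_namedtuple` are hypothetical: `Result` stands in for `TtestResult`, with two tuple fields but four constructor arguments.

```python
from collections import namedtuple


# Hypothetical stand-in for TtestResult: only (statistic, pvalue) are tuple
# fields, but the constructor also requires df and alternative.
class Result(namedtuple("_Base", ["statistic", "pvalue"])):
    def __new__(cls, statistic, pvalue, *args, **kwargs):
        # Only the NamedTuple fields are stored in the tuple itself.
        return super().__new__(cls, statistic, pvalue)

    def __init__(self, statistic, pvalue, df, alternative):
        # Extra attributes live in the instance __dict__, not in the tuple.
        self.df = df
        self.alternative = alternative


def rebuild_like_namedtuple(obj):
    """Assumed JAX behaviour: flatten to the tuple fields, call the type on them."""
    leaves = tuple(obj)  # only 2 leaves: df and alternative are not included
    return type(obj)(*leaves)


res = Result(-1.73, 0.33, 1.0, "two-sided")
print(len(res))  # 2
try:
    rebuild_like_namedtuple(res)
except TypeError as exc:
    print(exc)  # ... missing 2 required positional arguments: 'df' and 'alternative'
```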
Proposed design
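Use pickle hooks (e.g. `__reduce__` / `__reduce_ex__`) to automatically unpack complex return values into their array components inside the jitted function, and repack them after it returns.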
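A minimal sketch of one way this could work, assuming the hook meant is the standard pickle protocol (`object.__reduce_ex__`), which `pickle` and `copy` use to rebuild objects, typically without calling `__init__`. Inside the jitted function the result is decomposed into a rebuild callable, its arguments, and the instance state; only the array leaves are returned, and the object is rebuilt outside. All names (`unpack`, `repack`, `jit_repacking`, `is_array`, `Result`) are hypothetical and not taken from `lazy_xp_function`. The demo uses an identity function in place of `jax.jit` and Python floats in place of arrays so that it runs without JAX.

```python
from collections import namedtuple

_LEAF = object()  # marks a leaf position in a container skeleton


def unpack(obj):
    """Decompose obj via the pickle protocol into (rebuild, args, state)."""
    reduced = obj.__reduce_ex__(2)
    if isinstance(reduced, str):
        raise TypeError(f"{obj!r} reduces to a global name")
    rebuild, args, *rest = reduced
    state, listitems, dictitems, setter = (*rest, None, None, None, None)[:4]
    if listitems is not None or dictitems is not None or setter is not None:
        raise NotImplementedError("listitems/dictitems/state_setter unsupported")
    return rebuild, args, state


def repack(rebuild, args, state):
    """Rebuild like pickle does: call `rebuild`, then restore state, no __init__."""
    obj = rebuild(*args)
    if state is not None and hasattr(obj, "__setstate__"):
        obj.__setstate__(state)
    elif state is not None:
        dict_state, slots = state if isinstance(state, tuple) else (state, None)
        if dict_state:
            obj.__dict__.update(dict_state)
        for name, value in (slots or {}).items():
            setattr(obj, name, value)
    return obj


def _flatten(tree, leaves):
    """Append leaves of nested tuples/lists/dicts to `leaves`; return skeleton."""
    if type(tree) in (tuple, list):
        return type(tree)(_flatten(x, leaves) for x in tree)
    if type(tree) is dict:
        return {k: _flatten(v, leaves) for k, v in tree.items()}
    leaves.append(tree)
    return _LEAF


def _unflatten(skeleton, leaves):
    """Inverse of _flatten; `leaves` is an iterator."""
    if type(skeleton) in (tuple, list):
        return type(skeleton)(_unflatten(x, leaves) for x in skeleton)
    if type(skeleton) is dict:
        return {k: _unflatten(v, leaves) for k, v in skeleton.items()}
    return next(leaves)


def jit_repacking(func, jit, is_array):
    """Hypothetical wrapper: the jitted function returns a flat list of arrays."""
    # Non-array leaves (class, strings, ...) are recorded while tracing and
    # reused on later calls; this assumes they are identical across traces.
    static = {}

    def traced(*args):
        rebuild, rargs, state = unpack(func(*args))
        leaves = []
        static["skeleton"] = _flatten((rargs, state), leaves)
        static["rebuild"] = rebuild
        # 1-tuples keep a static None distinguishable from an array slot.
        static["leaves"] = [None if is_array(x) else (x,) for x in leaves]
        return [x for x in leaves if is_array(x)]

    jitted = jit(traced)

    def wrapper(*args):
        arrays = iter(jitted(*args))
        leaves = [next(arrays) if s is None else s[0] for s in static["leaves"]]
        rargs, state = _unflatten(static["skeleton"], iter(leaves))
        return repack(static["rebuild"], rargs, state)

    return wrapper


# Hypothetical stand-in for TtestResult: 2 tuple fields, 4 constructor args.
class Result(namedtuple("_Base", ["statistic", "pvalue"])):
    def __new__(cls, statistic, pvalue, *args, **kwargs):
        return super().__new__(cls, statistic, pvalue)

    def __init__(self, statistic, pvalue, df, alternative):
        self.df = df
        self.alternative = alternative


# Identity "jit" and floats as "arrays" keep the demo dependency-free.
wrapped = jit_repacking(
    lambda x, y: Result(x - y, 0.5, 1.0, "two-sided"),
    jit=lambda f: f,
    is_array=lambda v: isinstance(v, float),
)
res = wrapped(0.1, 0.3)
print(type(res).__name__, res.df, res.alternative)  # Result 1.0 two-sided
```

With JAX one would presumably pass `jit=jax.jit` and a predicate along the lines of `lambda v: isinstance(v, jax.Array)`. Everything that is not an array (the class, the rebuild callable, strings such as `alternative`) is kept out of the jitted outputs, because `jax.jit` cannot return such leaves.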