Closed
Description
repro:
import jax
import jax.custom_transpose
import jax.numpy as jnp

@jax.custom_jvp
def f(a, b):
    return b

def g_bwd(res, g):
    # Returns None as the cotangent for the first entry of dx.
    return (None, jnp.zeros_like(res[1]))

@jax.custom_transpose.custom_transpose
def g_tan(res, dx):
    return dx[1]

g_tan.def_transpose(g_bwd)

def fwd(x, dx):
    a, b = x
    y = f(*x)
    res = (a, b, 1)
    tan_out_types = jax.typeof(y)
    return y, g_tan(tan_out_types, res, dx)

f.defjvp(fwd)

@jax.jit
def test(a, b):
    return jax.grad(f, argnums=(1,))(a, b)

test(None, jnp.float32(2))  # raises ValueError: foreach() argument 2 is longer than argument 1
I don't know exactly how the custom_transpose logic works, but my guess is that the code uses None as a sentinel inside a user-provided pytree. I think this happens here:
https://github.com/jax-ml/jax/blob/main/jax/_src/custom_transpose.py#L228
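For context, here is a minimal sketch of why a None inside a user pytree can produce a length mismatch in general. It does not reproduce what custom_transpose.py does internally (that remains my guess); it only shows the standard pytree behavior that None is treated as an empty node and dropped on flattening unless `is_leaf` says otherwise. The variable names are made up for illustration.

```python
import jax
import jax.numpy as jnp

# Same structure as the (a, b) arguments in the repro, with a = None.
tree = (None, jnp.float32(2.0))

# By default, None is an empty pytree node, so it contributes no leaves.
leaves_default = jax.tree_util.tree_leaves(tree)

# Treating None as a leaf keeps one entry per position in the tuple.
leaves_with_none = jax.tree_util.tree_leaves(tree, is_leaf=lambda x: x is None)

print(len(leaves_default))    # 1
print(len(leaves_with_none))  # 2
```

If one side of a zip-like call is built from the first kind of flattening and the other from the second, the two lists differ in length, which would be consistent with the "argument 2 is longer than argument 1" error above.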
System info (python version, jaxlib version, accelerator, etc.)
Not applicable