When using `custom_jvp` or `custom_vjp`, don't use `nondiff_argnums` for array-valued arguments. Doing so often leads to "encountered an unexpected tracer" errors.
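As a minimal sketch of the alternative, assuming the array argument simply needs no derivative, it can be passed as an ordinary argument and given a zero cotangent in the backward rule. The names `masked_mul`, `masked_mul_fwd`, `masked_mul_bwd` and `loss` are hypothetical; the `None` return from `bwd` relies on `custom_vjp` accepting `None` to mean a zero cotangent.

```python
import jax
import jax.numpy as jnp

# Hypothetical sketch: `mask` is an ordinary argument of the custom_vjp function
# rather than a nondiff_argnums entry, so it may be a tracer (e.g. under jit or scan).
@jax.custom_vjp
def masked_mul(arr, mask):
    return arr * mask

def masked_mul_fwd(arr, mask):
    # Return the primal output plus `mask` as a residual for the backward pass.
    return arr * mask, mask

def masked_mul_bwd(mask, g):
    # Cotangent for `arr`; None signals a zero cotangent for the boolean `mask`.
    return g * mask, None

masked_mul.defvjp(masked_mul_fwd, masked_mul_bwd)

def loss(arr, mask):
    return masked_mul(arr, mask).sum()

x = jnp.arange(4.0)
mask = jnp.array([True, False, True, False])
# Under jit, `mask` is a tracer; passing it as a regular argument is fine.
grads = jax.jit(jax.grad(loss))(x, mask)
```

Here `jax.grad` differentiates only with respect to `arr` (the default `argnums=0`), so the boolean `mask` never needs a meaningful cotangent.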
But we should raise a better error, and make the docs more discoverable (and clearer).
Here's a repro from a user:
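```python
import jax
from jax import numpy as jnp

def func_fwd(arr, mask):
    return arr * mask

def func_jvp(mask, primals, tangents):
    # With nondiff_argnums=(1,), `mask` is passed as the first argument.
    def f(arr):
        return arr * mask
    return jax.jvp(f, primals, tangents)

# `mask` (an array) is marked non-differentiable via nondiff_argnums.
func = jax.custom_jvp(func_fwd, nondiff_argnums=(1,))
func.defjvp(func_jvp)

def step(carry, _):
    # Inside scan, carry[1] (the mask) is a tracer, not a concrete array.
    return (func(*carry), carry[1]), None

def loss(x, mask):
    carry, _ = jax.lax.scan(step, (x, mask), [None] * 2, length=2)
    return carry[0].sum()

x = jnp.ones(10)
mask = jnp.zeros(10, dtype=bool)
# Reported to fail with an "encountered an unexpected tracer" error.
gradients = jax.grad(loss)(x, mask)
```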
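One way to apply the advice above to this repro, offered as a sketch rather than a confirmed fix, is to drop `nondiff_argnums` and make `mask` a regular primal argument whose tangent the rule ignores. The rule's structure below is an assumption about what the user intended (the output tangent is just `arr_dot * mask`).

```python
import jax
from jax import numpy as jnp

# Sketch: `mask` is a regular primal argument instead of a nondiff_argnums entry.
@jax.custom_jvp
def func(arr, mask):
    return arr * mask

@func.defjvp
def func_jvp(primals, tangents):
    arr, mask = primals
    # Ignore mask's tangent (a float0 zero for a boolean mask).
    arr_dot, _ = tangents
    # The output tangent is linear in arr_dot, as custom_jvp requires.
    return arr * mask, arr_dot * mask

def step(carry, _):
    # carry[1] may now be a tracer without causing problems.
    return (func(*carry), carry[1]), None

def loss(x, mask):
    carry, _ = jax.lax.scan(step, (x, mask), [None] * 2, length=2)
    return carry[0].sum()

x = jnp.ones(10)
mask = jnp.zeros(10, dtype=bool)
gradients = jax.grad(loss)(x, mask)
```

Because the mask here is all `False`, the resulting gradient is all zeros; the point is only that the call now traces cleanly through `scan`.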