improve error message and docs for custom_jvp / custom_vjp nondiff_argnums escaped tracer error #20889

Open
@mattjj

When using custom_jvp or custom_vjp, don't use nondiff_argnums for array-valued arguments. Doing so often leads to "encountered an unexpected tracer" errors.

But we should raise a better error, and make the docs more discoverable (and clearer).

Here's a repro from a user:

import jax
from jax import numpy as jnp

def func_fwd(arr, mask):
  return arr * mask

def func_jvp(mask, primals, tangents):
  def f(arr):
    return arr * mask
  return jax.jvp(f, primals, tangents)

# mask (argument 1) is array-valued but is marked as non-differentiable here.
func = jax.custom_jvp(func_fwd, nondiff_argnums=(1,))
func.defjvp(func_jvp)

def step(carry, _):
  return (func(*carry), carry[1]), None

def loss(x, mask):
  carry, _ = jax.lax.scan(step, (x, mask), [None] * 2, length=2)
  return carry[0].sum()

x = jnp.ones(10)
mask = jnp.zeros(10, dtype=bool)

# Reported to fail with an "unexpected tracer" error.
gradients = jax.grad(loss)(x, mask)
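
For comparison, here's a sketch of one way to write the repro without nondiff_argnums. It assumes it's acceptable to treat mask as an ordinary argument of the custom_jvp function and simply ignore its tangent in the rule. The rule body below is written by hand for illustration and is not taken from the user's code:

```python
import jax
from jax import numpy as jnp

# mask is a regular argument here; no nondiff_argnums.
@jax.custom_jvp
def func(arr, mask):
  return arr * mask

@func.defjvp
def func_jvp(primals, tangents):
  arr, mask = primals
  arr_dot, _ = tangents  # the tangent for mask is ignored
  # arr * mask is linear in arr, so the output tangent is arr_dot * mask.
  return arr * mask, arr_dot * mask

def step(carry, _):
  return (func(*carry), carry[1]), None

def loss(x, mask):
  carry, _ = jax.lax.scan(step, (x, mask), None, length=2)
  return carry[0].sum()

x = jnp.ones(10)
mask = jnp.zeros(10, dtype=bool)

# Differentiates with respect to x only (argnums defaults to 0).
gradients = jax.grad(loss)(x, mask)
```

Since mask is all False in this input, the loss doesn't depend on x, so a mask with some True entries is more useful for checking the rule.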
