
custom_transpose fails on user-provided None values #29009

Closed
@jheek

Description

Repro:

import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(a, b):
  return b


def g_bwd(res, g):
  return (None, jnp.zeros_like(res[1]))


@jax.custom_transpose.custom_transpose
def g_tan(res, dx):
  return dx[1]

g_tan.def_transpose(g_bwd)

def fwd(x, dx):
  a, b = x
  y = f(*x)
  res = (a, b, 1)
  tan_out_types = jax.typeof(y)
  return y, g_tan(tan_out_types, res, dx)

f.defjvp(fwd)

@jax.jit
def test(a, b):
  return jax.grad(f, argnums=(1,))(a, b)

test(None, jnp.float32(2))  # raises ValueError: foreach() argument 2 is longer than argument 1

I don't know exactly how the custom_transpose logic works, but my guess is that the code uses None as a sentinel inside a user-provided pytree, which I think happens here:

https://github.com/jax-ml/jax/blob/main/jax/_src/custom_transpose.py#L228

System info (python version, jaxlib version, accelerator, etc.)

Not applicable

Labels: bug (Something isn't working)