Description
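```python
import jax

# Loop body: print the carry (as an ordered effect), then increment it.
def f(c, _):
    jax.debug.print("c = {c}", c=c, ordered=True)
    return c + 1, None

def g(x):
    return jax.lax.scan(f, x, length=2)[0]

# Inspect the jaxpr of the differentiated function.
print(jax.make_jaxpr(jax.value_and_grad(g))(1.0))
```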
This prints:
```
{ lambda ; a:f32[]. let
    b:f32[] = scan[
      _split_transpose=False
      jaxpr={ lambda ; c:f32[]. let d:f32[] = add c 1.0:f32[] in (d,) }
      length=2
      linear=(False,)
      num_carry=1
      num_consts=0
      reverse=False
      unroll=1
    ] a
    e:f32[] = scan[
      _split_transpose=False
      jaxpr={ lambda ; f:f32[]. let in (f,) }
      length=2
      linear=(True,)
      num_carry=1
      num_consts=0
      reverse=True
      unroll=1
    ] 1.0:f32[]
  in (b, e) }
```
Note that the `jax.debug.print` has been dropped from the forward pass.
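For contrast, here is a quick check that traces `g` on its own, without `value_and_grad`. It is a sketch that assumes `jax.debug.print` is staged out as a `debug_callback` equation; if that holds, the scan body in this jaxpr should still contain the print:

```python
import jax


def f(c, _):
    # Same loop body as the reproducer above: print the carry, then increment it.
    jax.debug.print("c = {c}", c=c, ordered=True)
    return c + 1, None


def g(x):
    return jax.lax.scan(f, x, length=2)[0]


# Trace g alone, with no differentiation involved.
fwd_only = jax.make_jaxpr(g)(1.0)
print(fwd_only)

# Assumption: the print appears as a `debug_callback` equation inside the scan body.
print("debug_callback" in str(fwd_only))
```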
I narrowed it down to the equations getting lost in `jax/_src/lax/control_flow/loops.py`, lines 739 to 743 (at commit af66ca9), because `tracers_to_jaxpr` won't capture these effectful equations.
AFAICT, this has never worked (I tested versions back to 0.4.6, the oldest jaxlib on PyPI), but it should!
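To check this across versions without eyeballing the printed jaxpr, something like the following could count the print equations that survive tracing. This is a minimal sketch: `iter_eqns` and `count_debug_callbacks` are hypothetical helpers, not JAX APIs. It relies on jaxpr internals (`eqns`, `params`, `primitive.name`) that may change between releases, and it assumes the print is staged out as a `debug_callback` primitive:

```python
import jax


def f(c, _):
    jax.debug.print("c = {c}", c=c, ordered=True)
    return c + 1, None


def g(x):
    return jax.lax.scan(f, x, length=2)[0]


def iter_eqns(jaxpr):
    """Yield every equation in `jaxpr`, recursing into sub-jaxprs held in params."""
    jaxpr = getattr(jaxpr, "jaxpr", jaxpr)  # unwrap a ClosedJaxpr to its Jaxpr
    for eqn in jaxpr.eqns:
        yield eqn
        for param in eqn.params.values():
            # Params such as scan's `jaxpr` or cond's `branches` hold sub-jaxprs.
            subs = param if isinstance(param, (tuple, list)) else (param,)
            for sub in subs:
                if hasattr(getattr(sub, "jaxpr", sub), "eqns"):
                    yield from iter_eqns(sub)


def count_debug_callbacks(fun, *args):
    """Count `debug_callback` equations anywhere in the traced jaxpr of `fun`."""
    closed = jax.make_jaxpr(fun)(*args)
    return sum(eqn.primitive.name == "debug_callback" for eqn in iter_eqns(closed))


print("forward only:  ", count_debug_callbacks(g, 1.0))
print("value_and_grad:", count_debug_callbacks(jax.value_and_grad(g), 1.0))
```

A zero count on the second line, alongside a nonzero count on the first, would reproduce the dropped print shown above.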
System info (python version, jaxlib version, accelerator, etc.)
...