debug.print inside scan gets lost during AD #28738

Open
@dfm

Description

```python
import jax

def f(c, _):
  jax.debug.print("c = {c}", c=c, ordered=True)
  return c + 1, None

def g(x):
  return jax.lax.scan(f, x, length=2)[0]

jax.make_jaxpr(jax.value_and_grad(g))(1.0)
```

prints

```
{ lambda ; a:f32[]. let
    b:f32[] = scan[
      _split_transpose=False
      jaxpr={ lambda ; c:f32[]. let d:f32[] = add c 1.0:f32[] in (d,) }
      length=2
      linear=(False,)
      num_carry=1
      num_consts=0
      reverse=False
      unroll=1
    ] a
    e:f32[] = scan[
      _split_transpose=False
      jaxpr={ lambda ; f:f32[]. let  in (f,) }
      length=2
      linear=(True,)
      num_carry=1
      num_consts=0
      reverse=True
      unroll=1
    ] 1.0:f32[]
  in (b, e) }
```

The `debug.print` has been dropped from the forward pass: the body of the first `scan` contains only the `add`.
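A quick way to confirm this without eyeballing the printed jaxpr is to walk the equations recursively and collect primitive names. This is a minimal sketch: `primitive_names` is a hypothetical helper, and it assumes that `jax.debug.print` appears as a `debug_callback` equation and that `scan` stores its body in the `jaxpr` param (as in the output above). Sub-jaxprs are detected by duck typing rather than by importing JAX's internal classes.

```python
import jax

def f(c, _):
    jax.debug.print("c = {c}", c=c, ordered=True)
    return c + 1, None

def g(x):
    return jax.lax.scan(f, x, length=2)[0]

def primitive_names(jaxpr):
    """Hypothetical helper: recursively collect primitive names from a Jaxpr,
    including sub-jaxprs stored in equation params (e.g. scan's `jaxpr`)."""
    names = []
    for eqn in jaxpr.eqns:
        names.append(eqn.primitive.name)
        for param in eqn.params.values():
            # Some params hold a tuple of sub-jaxprs (e.g. cond branches).
            subs = param if isinstance(param, (tuple, list)) else [param]
            for sub in subs:
                if hasattr(sub, "eqns") and hasattr(sub, "invars"):
                    # Looks like an open Jaxpr.
                    names.extend(primitive_names(sub))
                elif hasattr(sub, "jaxpr") and hasattr(sub.jaxpr, "eqns"):
                    # Looks like a ClosedJaxpr wrapping a Jaxpr.
                    names.extend(primitive_names(sub.jaxpr))
    return names

plain = jax.make_jaxpr(g)(1.0)
ad = jax.make_jaxpr(jax.value_and_grad(g))(1.0)

# Without AD, the scan body should contain the print.
print("debug_callback" in primitive_names(plain.jaxpr))
# Under value_and_grad, this is False on versions affected by this issue.
print("debug_callback" in primitive_names(ad.jaxpr))
```

Comparing the two checks isolates the regression to the AD path, since the same `f` keeps its print when traced directly.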

I narrowed it down to the print getting lost here:

```python
jaxpr_known_, invar_pvals_out, jaxpr_known_consts = pe.trace_to_jaxpr_nounits(
    lu.wrap_init(core.jaxpr_as_fun(jaxpr_known),
                 debug_info=jaxpr_known.jaxpr.debug_info),
    const_pvals + other_pvals,
    instantiate=[True] * (len(out_uk) - sum(out_uk)) + [False] * num_res)
```

This happens because `tracers_to_jaxpr` doesn't capture these effectful equations.
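To make the diagnosis concrete, here is a toy model of rebuilding a jaxpr by walking backward from its outputs. It is not JAX's implementation: `Eqn` and `reconstruct_from_outputs` are hypothetical names, and the sketch only assumes that reconstruction is driven by which outputs are needed. A `debug_callback` equation produces no outputs, so an output-driven walk never reaches it unless effectful equations are kept explicitly.

```python
from dataclasses import dataclass

@dataclass
class Eqn:
    name: str
    invars: tuple
    outvars: tuple
    effectful: bool = False

def reconstruct_from_outputs(eqns, outputs, keep_effects=False):
    """Toy model: keep only eqns whose outputs are (transitively) needed by
    `outputs`; optionally also keep effectful eqns as extra roots."""
    needed = set(outputs)
    kept = []
    for eqn in reversed(eqns):
        if any(v in needed for v in eqn.outvars) or (keep_effects and eqn.effectful):
            kept.append(eqn)
            needed.update(eqn.invars)  # its inputs become needed too
    return [e.name for e in reversed(kept)]

# The scan body from the repro: a print that consumes `c` and produces
# nothing, followed by `d = c + 1`.
body = [
    Eqn("debug_callback", invars=("c",), outvars=(), effectful=True),
    Eqn("add", invars=("c",), outvars=("d",)),
]

print(reconstruct_from_outputs(body, outputs=("d",)))
# ['add']
print(reconstruct_from_outputs(body, outputs=("d",), keep_effects=True))
# ['debug_callback', 'add']
```

The `keep_effects=True` branch only sketches one way of treating effectful equations as extra roots; it is an assumption, not the fix adopted in JAX.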

AFAICT, this has never worked (I tried versions back to 0.4.6, the oldest jaxlib on PyPI), but it should!

System info (python version, jaxlib version, accelerator, etc.)

...

Labels: bug (Something isn't working)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions