Skip to content

Commit

Permalink
Merge pull request #26290 from mattjj:linearize-name-stack-fixes-2
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 722856808
  • Loading branch information
Google-ML-Automation committed Feb 4, 2025
2 parents 5e7e691 + 8f967c5 commit 363f1e6
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
11 changes: 9 additions & 2 deletions jax/_src/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def linearize_jaxpr(
return _linearize_jaxpr(jaxpr, tuple(nonzeros))

@weakref_lru_cache
@source_info_util.reset_name_stack()
def _linearize_jaxpr(
jaxpr: core.ClosedJaxpr,
nonzeros: tuple[bool, ...]
Expand Down Expand Up @@ -192,7 +193,8 @@ def direct_linearize(traceable: lu.WrappedFun,
linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=tag)
tracers = [LinearizeTracer(linearize_trace, p, t) for p, t in zip(primals, tangents)]
tracers = [t.full_lower() for t in tracers]
with core.set_current_trace(linearize_trace, check_leaks=True):
with (core.set_current_trace(linearize_trace, check_leaks=True),
source_info_util.transform_name_stack('jvp')):
if has_aux:
ans, aux = traceable.call_wrapped(*tracers)
aux_primals = [x.primal
Expand Down Expand Up @@ -587,6 +589,10 @@ def __init__(self, parent_trace, tangent_trace, tag=None):
self.tag = core.TraceTag() if tag is None else tag
self.parent_trace = parent_trace
self.tangent_trace = tangent_trace
self._name_stack_prefix_len = len(source_info_util.current_name_stack())

def _name_stack_suffix(self):
return source_info_util.current_name_stack()[self._name_stack_prefix_len:]

def to_primal_tangent_pair(self, val):
if isinstance(val, LinearizeTracer) and val._trace.tag is self.tag:
Expand All @@ -605,7 +611,8 @@ def process_primitive(self, primitive, args, params):
with core.set_current_trace(self.parent_trace):
primal_out, tangent_nzs_out, residuals, linearized = lin(
tangent_nzs, *primals_in, **params)
with core.set_current_trace(self.tangent_trace):
with (core.set_current_trace(self.tangent_trace),
source_info_util.set_name_stack(self._name_stack_suffix())):
tangent_out = linearized(residuals, *tangents_in)
if primitive.multiple_results:
return [maybe_linearize_tracer(self, x, nz, t)
Expand Down
3 changes: 2 additions & 1 deletion tests/name_stack_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,8 @@ def test_grad_should_add_jvp_and_transpose_to_call_jaxpr(self):
@jax.jit
def f(x):
with jax.named_scope('bar'):
return jnp.sin(x)
# return jnp.sin(x)
return jax.lax.sin(x)
jaxpr = jax.make_jaxpr(f)(1.).jaxpr
jaxpr_param = 'jaxpr'

Expand Down

0 comments on commit 363f1e6

Please sign in to comment.