Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[direct-linearize] fix name stack tests #26290

Merged
merged 1 commit into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions jax/_src/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def linearize_jaxpr(
return _linearize_jaxpr(jaxpr, tuple(nonzeros))

@weakref_lru_cache
@source_info_util.reset_name_stack()
def _linearize_jaxpr(
jaxpr: core.ClosedJaxpr,
nonzeros: tuple[bool, ...]
Expand Down Expand Up @@ -192,7 +193,8 @@ def direct_linearize(traceable: lu.WrappedFun,
linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=tag)
tracers = [LinearizeTracer(linearize_trace, p, t) for p, t in zip(primals, tangents)]
tracers = [t.full_lower() for t in tracers]
with core.set_current_trace(linearize_trace, check_leaks=True):
with (core.set_current_trace(linearize_trace, check_leaks=True),
source_info_util.transform_name_stack('jvp')):
if has_aux:
ans, aux = traceable.call_wrapped(*tracers)
aux_primals = [x.primal
Expand Down Expand Up @@ -587,6 +589,10 @@ def __init__(self, parent_trace, tangent_trace, tag=None):
self.tag = core.TraceTag() if tag is None else tag
self.parent_trace = parent_trace
self.tangent_trace = tangent_trace
self._name_stack_prefix_len = len(source_info_util.current_name_stack())

def _name_stack_suffix(self):
return source_info_util.current_name_stack()[self._name_stack_prefix_len:]

def to_primal_tangent_pair(self, val):
if isinstance(val, LinearizeTracer) and val._trace.tag is self.tag:
Expand All @@ -605,7 +611,8 @@ def process_primitive(self, primitive, args, params):
with core.set_current_trace(self.parent_trace):
primal_out, tangent_nzs_out, residuals, linearized = lin(
tangent_nzs, *primals_in, **params)
with core.set_current_trace(self.tangent_trace):
with (core.set_current_trace(self.tangent_trace),
source_info_util.set_name_stack(self._name_stack_suffix())):
tangent_out = linearized(residuals, *tangents_in)
if primitive.multiple_results:
return [maybe_linearize_tracer(self, x, nz, t)
Expand Down
3 changes: 2 additions & 1 deletion tests/name_stack_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,8 @@ def test_grad_should_add_jvp_and_transpose_to_call_jaxpr(self):
@jax.jit
def f(x):
with jax.named_scope('bar'):
return jnp.sin(x)
# return jnp.sin(x)
return jax.lax.sin(x)
jaxpr = jax.make_jaxpr(f)(1.).jaxpr
jaxpr_param = 'jaxpr'

Expand Down
Loading