Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[direct-linearize] fix name stack tests #26290

Merged
merged 1 commit into from
Feb 4, 2025

Conversation

mattjj
Copy link
Collaborator

@mattjj mattjj commented Feb 4, 2025

Now this passes:

JAX_USE_DIRECT_LINEARIZE=1 python tests/name_stack_test.py

See #8395 for an explanation of how the name stack should work. There were three changes we needed to replicate current behavior:

  1. push 'jvp' onto the name stack when we trace the function in direct_linearize;
  2. to solve the "doubled naming" problem, set the name stack only to be the appropriate suffix when running the tangent part of the computation in LinearizeTrace.process_primitive;
  3. when doing the jaxpr-to-jaxpr linearize_jaxpr, be sure to reset the name stack (which was provided in trace_to_jaxpr_dynamic, but linearize_jaxpr doesn't use that and instead uses lower-level stuff).

Co-authored-by: Sharad Vikram <sharadmv@google.com>
@mattjj mattjj requested a review from dougalm February 4, 2025 00:50
@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Feb 4, 2025
@copybara-service copybara-service bot merged commit 363f1e6 into jax-ml:main Feb 4, 2025
22 of 23 checks passed
@mattjj mattjj deleted the linearize-name-stack-fixes-2 branch February 4, 2025 17:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants