Skip to content

[hijax] prototype hijax pieces #28781

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 15, 2025
Merged

[hijax] prototype hijax pieces #28781

merged 1 commit into from
May 15, 2025

Conversation

mattjj
Copy link
Collaborator

@mattjj mattjj commented May 15, 2025

This shouldn't affect any existing behaviors, and it shouldn't regress tracing time noticeably.

The main problem we're trying to solve is grad-of-jit where the jit abstracts a Box. In that case, we must pass a reference to the abstracted box to the backward pass. To do that, we make it part of the AD jaxpr. And to avoid polluting all jaxprs, we introduce a new (currently AD-specific) hijax jaxpr, which has more types/primitives than the existing (lojax) jaxprs.

The main implementation ideas:

  • each Trace is tagged with a requires_low: bool
  • each Jaxpr
    • is tagged with an is_high: bool, default False but set True while tracing if any hijax primitives are encountered
    • includes an mut_types: dict[Var, HijaxType] indicating final types for type-changing mutable hijax types
  • each AbstractValue is tagged by a mutable: bool which is read to populate mut_types
  • each Primitive
    • has an is_high(**params) -> bool method (depends on params for HOPs)
    • has a to_lojax(*args, **params) method taking and returning hijaxtypes-wrapping-lowtracers
  • in Primitive.bind, we check if prim.is_high(**params) and trace.requires_low, and if so we call prim.to_lojax

We plan to revise the implementation to look like having a ToLojax trace that sits above the default EvalTrace instance at the top level.

Co-implemented with @dougalm. Design explanation forthcoming.

@mattjj mattjj requested a review from yashk2810 May 15, 2025 21:29
@mattjj mattjj force-pushed the hijax branch 2 times, most recently from 0057a57 to 5dac580 Compare May 15, 2025 21:33
@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels May 15, 2025
@mattjj mattjj force-pushed the hijax branch 5 times, most recently from 7199a7e to 4a8830a Compare May 15, 2025 21:58
shouldn't affect existing behaviors, or trace time

The main implementation ideas:
* each Trace is tagged with a `requires_low: bool`
* each Jaxpr
  * is tagged with an `is_high: bool`, default False but set True while tracing
    if any hijax primitives are encountered
  * includes an `mut_types: dict[Var, HijaxType]` indicating final types for
    type-changing mutable hijax types
* each AbstractValue is tagged by a `mutable: bool` which is read to populate
  `mut_types`
* each Primitive
  * has an `is_high(**params) -> bool` method (depends on params for HOPs)
  * has a `to_lojax(*args, **params)` method taking and returning
    hijaxtypes-wrapping-lowtracers
* in `Primitive.bind`, we check if `prim.is_high(**params) and
  trace.requires_low`, and if so we call `prim.to_lojax`

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
@copybara-service copybara-service bot merged commit 5b2f399 into jax-ml:main May 15, 2025
23 checks passed
@mattjj mattjj deleted the hijax branch May 15, 2025 22:59
copybara-service bot pushed a commit that referenced this pull request May 16, 2025
PiperOrigin-RevId: 759648570
copybara-service bot pushed a commit that referenced this pull request May 16, 2025
PiperOrigin-RevId: 759648570
copybara-service bot pushed a commit that referenced this pull request May 16, 2025
PiperOrigin-RevId: 759658054
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
kokoro:force-run pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants