[hijax] prototype hijax pieces #28781
Merged
This shouldn't affect any existing behaviors, and it shouldn't regress tracing time noticeably.
The main problem we're trying to solve is `grad`-of-`jit` where the `jit` abstracts a `Box`. In that case, we must pass a reference to the abstracted box to the backward pass. To do that, we make it part of the AD jaxpr. And to avoid polluting all jaxprs, we introduce a new (currently AD-specific) hijax jaxpr, which has more types/primitives than the existing (lojax) jaxprs.

The main implementation ideas:
* `requires_low: bool`
* `is_high: bool`, default False but set True while tracing if any hijax primitives are encountered
* `mut_types: dict[Var, HijaxType]` indicating final types for type-changing mutable hijax types
* `mutable: bool` which is read to populate `mut_types`
* `is_high(**params) -> bool` method (depends on params for HOPs)
* `to_lojax(*args, **params)` method taking and returning hijaxtypes-wrapping-lowtracers
* in `Primitive.bind`, we check if `prim.is_high(**params) and trace.requires_low`, and if so we call `prim.to_lojax`
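
To make the dispatch check in `Primitive.bind` concrete, here is a minimal toy sketch in plain Python. It is not the JAX implementation: `ToyTrace`, `ToyPrimitive`, `ToyFma`, `use_trace`, and `run` are hypothetical names invented for illustration, and the trace stack is a simplifying assumption. Only the names `requires_low`, `is_high`, `to_lojax`, and the shape of the check come from the description above.

```python
import contextlib
from dataclasses import dataclass


@dataclass
class ToyTrace:
    # Hypothetical stand-in for a trace; it carries only the flag named above.
    requires_low: bool = False


# Assumption: a simple stack models "the current trace".
_trace_stack = [ToyTrace(requires_low=False)]
log = []  # records which primitives were actually executed


@contextlib.contextmanager
def use_trace(trace):
    _trace_stack.append(trace)
    try:
        yield
    finally:
        _trace_stack.pop()


class ToyPrimitive:
    """Hypothetical low-level (lojax-like) primitive."""

    def __init__(self, name, impl):
        self.name = name
        self.impl = impl

    def is_high(self, **params) -> bool:
        return False

    def to_lojax(self, *args, **params):
        raise NotImplementedError(f"{self.name} is already low-level")

    def bind(self, *args, **params):
        trace = _trace_stack[-1]
        # The check from the description: lower only when the primitive is
        # high and the current trace requires low-level primitives.
        if self.is_high(**params) and trace.requires_low:
            return self.to_lojax(*args, **params)
        log.append(self.name)
        return self.impl(*args, **params)


add_p = ToyPrimitive("add", lambda x, y: x + y)
mul_p = ToyPrimitive("mul", lambda x, y: x * y)


class ToyFma(ToyPrimitive):
    """Hypothetical high-level primitive computing x * y + z."""

    def __init__(self):
        super().__init__("fma", lambda x, y, z: x * y + z)

    def is_high(self, **params) -> bool:
        return True

    def to_lojax(self, x, y, z, **params):
        # Re-express the high-level op using only low-level primitives.
        return add_p.bind(mul_p.bind(x, y), z)


fma_p = ToyFma()


def run(trace, x, y, z):
    log.clear()
    with use_trace(trace):
        out = fma_p.bind(x, y, z)
    return out, list(log)
```

In this sketch, a trace with `requires_low=False` executes the high-level primitive directly, while a trace with `requires_low=True` sees only the low-level `mul` and `add` primitives. Both paths compute the same value.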
We plan to revise the implementation so that it works like a ToLojax trace sitting above the default EvalTrace instance at the top level.
Co-implemented with @dougalm. Design explanation forthcoming.