-
Notifications
You must be signed in to change notification settings - Fork 30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Frontend] Use custom lowering rules #1152
[Frontend] Use custom lowering rules #1152
Conversation
2dff6d1
to
584c15d
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Neat! Thanks for improving the code base.
There is an additional registration we perform for types, not sure if those have an equivalent or not.
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #1152 +/- ##
==========================================
- Coverage 97.87% 97.86% -0.01%
==========================================
Files 76 76
Lines 10850 10810 -40
Branches 1283 1281 -2
==========================================
- Hits 10619 10579 -40
Misses 179 179
Partials 52 52 ☔ View full report in Codecov by Sentry. |
Co-authored-by: David Ittah <dime10@users.noreply.github.com>
Not yet, but I'll create an issue in JAX and maybe it will get some traction? |
Co-authored-by: David Ittah <dime10@users.noreply.github.com>
**Context:** Using [`mlir.register_lowering`](https://github.com/jax-ml/jax/blob/ae86ef16c7a03409cb444da3d477b2adb8134e6f/jax/_src/interpreters/mlir.py#L814-L825) [will modify the global variable `_lowerings` or `platform_specific_lowerings`](https://github.com/jax-ml/jax/blob/ae86ef16c7a03409cb444da3d477b2adb8134e6f/jax/_src/interpreters/mlir.py#L810-L812). To avoid this, JAX provides the[ `LoweringParameters`](https://github.com/jax-ml/jax/blob/ae86ef16c7a03409cb444da3d477b2adb8134e6f/jax/_src/interpreters/mlir.py#L630-L649) structure to pass custom lowering rules. When lowering rules are passed using [`LoweringParameters`, these global variables will not be rewritten.](https://github.com/jax-ml/jax/blob/ae86ef16c7a03409cb444da3d477b2adb8134e6f/jax/_src/interpreters/mlir.py#L1745-L1751) **Description of the Change:** Use LoweringParameters instead of register_lowering **Benefits:** Avoid rewriting global data structures from JAX. **Possible Drawbacks:** None. **Related GitHub Issues:** --------- Co-authored-by: David Ittah <dime10@users.noreply.github.com>
Context: Using
mlir.register_lowering
will modify the global variable_lowerings
orplatform_specific_lowerings
. To avoid this, JAX provides theLoweringParameters
structure to pass custom lowering rules. When lowering rules are passed usingLoweringParameters
, these global variables will not be rewritten.Description of the Change: Use LoweringParameters instead of register_lowering
Benefits: Avoid rewriting global data structures from JAX.
Possible Drawbacks: None.
Related GitHub Issues: