-
Notifications
You must be signed in to change notification settings - Fork 30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Autograph automatic conversion of in-place operator-based array updates #1143
Conversation
08ee23f
to
d947d4f
Compare
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #1143 +/- ##
==========================================
+ Coverage 97.87% 97.89% +0.01%
==========================================
Files 75 76 +1
Lines 10801 10858 +57
Branches 1268 1281 +13
==========================================
+ Hits 10572 10629 +57
Misses 179 179
Partials 50 50 ☔ View full report in Codecov by Sentry. |
1. simplified the ag_primitives to only have a update_item_with_op which dispatches to the correct arithmatic operation. 2. Added support for updating a slice using assignment operators 3. Reflected the changes to the documentation 4. Added a test for slice update 5. Improved existing tests to reduce the calls to qjit. 6. Addressed some minor review comments
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Apart from the code coverage warnings, LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍
Context:
#717 added support for converting in-place array updates (
arr[i] = x
) into the equivalent JAX traceable code (arr.at[i].set(x)
). This PR extends that support to operator assignment array updates.Description of the Change:
AugAssign
ast nodes assigning to a single index or a slice subscript to calls toupdate_item_with_op
update_item_with_op
method that map to the correspondingjax.numpy.ndarray.at
equivalent methods for JAX arrays and the normal Python operator assignment otherwisetransform_ast
inCatalystTransformer
to invoke the new converterBenefits: We can use
arr[i] += x
instead ofarr.at[i].add(x)
.Possible Drawbacks: It would be cleaner to have the new converter live in the DiastaticMalt project.
Related GitHub Issues: #757
Based on the solution presented in this PR: #769
Note that this PR was originally implemented externally by #769. This PR aims to revisit that PR.
[sc-60318]