-
Notifications
You must be signed in to change notification settings - Fork 595
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Program Capture] higher order primitives support pytree input and output #6081
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #6081 +/- ##
==========================================
- Coverage 99.66% 99.65% -0.01%
==========================================
Files 430 431 +1
Lines 41857 41593 -264
==========================================
- Hits 41716 41451 -265
- Misses 141 142 +1 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good. Now that the while loop capture is merged, maybe worth adding pytree capture support for that as well?
Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
…ylane into capture-pytrees
Co-authored-by: Utkarsh <utkarshazad98@gmail.com> Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
Co-authored-by: Thomas R. Bromley <49409390+trbromley@users.noreply.github.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not going to block approval, but could you also add a test for qml.cond
with MCM predicates that accepts pytree arguments? Conditionals with MCMs cannot return anything, so the only thing that would need to be tested is pytree inputs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, just one small non-blocking doubt!
Added. |
Co-authored-by: Utkarsh <utkarshazad98@gmail.com>
…jacobian` (#6134) **Context:** #6120 and #6127 add support to capture `qml.grad` and `qml.jacobian` in plxpr. Once captured, they dispatch to `jax.grad` and `jax.jacobian`. **Description of the Change:** This PR adds support for pytree inputs and outputs of the differentiated functions, similar to #6081. For this, it extends the internal class `FlatFn` by the extra functionality to turn the wrapper into a `*flat_args -> *flat_outputs` function, instead of a `*pytree_args -> *flat_outputs` function. **Benefits:** Pytree support 🌳 **Possible Drawbacks:** **Related GitHub Issues:** [sc-70930] [sc-71862] --------- Co-authored-by: Christina Lee <christina@xanadu.ai> Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
…jacobian` (#6134) **Context:** #6120 and #6127 add support to capture `qml.grad` and `qml.jacobian` in plxpr. Once captured, they dispatch to `jax.grad` and `jax.jacobian`. **Description of the Change:** This PR adds support for pytree inputs and outputs of the differentiated functions, similar to #6081. For this, it extends the internal class `FlatFn` by the extra functionality to turn the wrapper into a `*flat_args -> *flat_outputs` function, instead of a `*pytree_args -> *flat_outputs` function. **Benefits:** Pytree support 🌳 **Possible Drawbacks:** **Related GitHub Issues:** [sc-70930] [sc-71862] --------- Co-authored-by: Christina Lee <christina@xanadu.ai> Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
…jacobian` (#6134) **Context:** #6120 and #6127 add support to capture `qml.grad` and `qml.jacobian` in plxpr. Once captured, they dispatch to `jax.grad` and `jax.jacobian`. **Description of the Change:** This PR adds support for pytree inputs and outputs of the differentiated functions, similar to #6081. For this, it extends the internal class `FlatFn` by the extra functionality to turn the wrapper into a `*flat_args -> *flat_outputs` function, instead of a `*pytree_args -> *flat_outputs` function. **Benefits:** Pytree support 🌳 **Possible Drawbacks:** **Related GitHub Issues:** [sc-70930] [sc-71862] --------- Co-authored-by: Christina Lee <christina@xanadu.ai> Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
…jacobian` (#6134) **Context:** #6120 and #6127 add support to capture `qml.grad` and `qml.jacobian` in plxpr. Once captured, they dispatch to `jax.grad` and `jax.jacobian`. **Description of the Change:** This PR adds support for pytree inputs and outputs of the differentiated functions, similar to #6081. For this, it extends the internal class `FlatFn` by the extra functionality to turn the wrapper into a `*flat_args -> *flat_outputs` function, instead of a `*pytree_args -> *flat_outputs` function. **Benefits:** Pytree support 🌳 **Possible Drawbacks:** **Related GitHub Issues:** [sc-70930] [sc-71862] --------- Co-authored-by: Christina Lee <christina@xanadu.ai> Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
…jacobian` (#6134) **Context:** #6120 and #6127 add support to capture `qml.grad` and `qml.jacobian` in plxpr. Once captured, they dispatch to `jax.grad` and `jax.jacobian`. **Description of the Change:** This PR adds support for pytree inputs and outputs of the differentiated functions, similar to #6081. For this, it extends the internal class `FlatFn` by the extra functionality to turn the wrapper into a `*flat_args -> *flat_outputs` function, instead of a `*pytree_args -> *flat_outputs` function. **Benefits:** Pytree support 🌳 **Possible Drawbacks:** **Related GitHub Issues:** [sc-70930] [sc-71862] --------- Co-authored-by: Christina Lee <christina@xanadu.ai> Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
Context:
Similar to jax, we should be able to catch higher order primitives that accept arbitrary pytree inputs and outputs. Currently, our primitives assume flat inputs and flat outputs.
Description of the Change:
Updates
qnode_call
,for_loop
, andcond
to allow arbitrary pytree inputs and outputs.For example,
Benefits:
Easier to cart complicated combinations of variables around. More versatile program capture.
Possible Drawbacks:
If different branches (ie cond) have different pytree structures, we may run into weird behavior.
Related GitHub Issues: