[Program Capture] Capture & execute `qml.grad` in plxpr #6120
Conversation
Hello. You may have forgotten to update the changelog!
Codecov Report: All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

```
@@           Coverage Diff            @@
##           master    #6120    +/-  ##
=======================================
  Coverage   99.60%   99.60%
=======================================
  Files         445      446      +1
  Lines       42364    42423     +59
=======================================
+ Hits        42198    42257     +59
  Misses        166      166
```

☔ View full report in Codecov by Sentry.
Just leaving some (hopefully helpful) comments.
🎉
[Program Capture] Capture & execute `qml.grad` in plxpr
**Context:** We're adding support for differentiation in plxpr, also see #6120.

**Description of the Change:** This PR adds support for `qml.jacobian`, similar to the support for `qml.grad`. Note that pytree support will be needed to allow for multi-argument derivatives.

**Benefits:** Capture derivatives of non-scalar functions.

**Possible Drawbacks:** See the discussion around `qml.grad` in #6120.

**Related GitHub Issues:** [sc-71860]

Co-authored-by: Christina Lee <christina@xanadu.ai>
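Once captured, the jacobian primitive dispatches to `jax.jacobian`. A minimal sketch of the intended semantics in plain JAX (the function `f` below is a hypothetical example, not from the PR):

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical non-scalar function R^2 -> R^2, standing in for
    # a captured quantum function with vector-valued output.
    return jnp.array([jnp.sin(x[0]), x[0] * x[1]])


# jax.jacobian handles non-scalar outputs, which jax.grad rejects.
jac = jax.jacobian(f)(jnp.array([0.0, 2.0]))
# Row 0: d sin(x0) -> [cos(x0), 0]; row 1: d (x0*x1) -> [x1, x0]
```

This is only an illustration of the dispatch target; the actual capture path records the function as a nested jaxpr first.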
**Context:** The new `qml.capture` module does not support differentiation yet.

**Description of the Change:** This PR takes the first step towards differentiability in plxpr. It adds the capability of capturing `qml.grad` as a "nested jaxpr" primitive. When executing the captured program, `qml.grad` is essentially changed to `jax.grad`, because executing Autograd autodifferentiation within the jaxpr ecosystem is not sensible.

**Benefits:** Captures the first differentiation instructions.

**Possible Drawbacks:** The current implementation requires a `jvp` construction for every evaluation of a QNode gradient. This means that this JVP function is reconstructed for every evaluation call, if I'm not mistaken, making the code significantly less performant with `capture` than without. Of course, the longer-term plan is to process the plxpr into lower-level code by lowering the `grad` primitive itself, in which case this problem goes away. A similar redundancy is implemented in `QNode`: whenever a `qnode` primitive is evaluated, a new `QNode` is created (and only ever evaluated once). This disables caching, for example, unless a cache is passed around explicitly.

**Related GitHub Issues:** [sc-71858]

Co-authored-by: Christina Lee <christina@xanadu.ai>
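The execution-time semantics described above ("`qml.grad` is essentially changed to `jax.grad`") can be sketched in plain JAX; the cost function `f` here is a hypothetical stand-in for a captured QNode:

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical scalar cost function standing in for a QNode.
    return jnp.sin(x) ** 2


# Under capture, qml.grad records a "grad" primitive holding the
# jaxpr of f; evaluating that primitive dispatches to jax.grad.
g = jax.grad(f)(0.5)  # d/dx sin^2(x) = 2 sin(x) cos(x) = sin(2x)
```

Note that `jax.grad` itself builds a JVP under the hood on every call here, which mirrors the per-evaluation `jvp` reconstruction cost discussed in the drawbacks.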
…jacobian` (#6134)

**Context:** #6120 and #6127 add support to capture `qml.grad` and `qml.jacobian` in plxpr. Once captured, they dispatch to `jax.grad` and `jax.jacobian`.

**Description of the Change:** This PR adds support for pytree inputs and outputs of the differentiated functions, similar to #6081. For this, it extends the internal class `FlatFn` with the extra functionality to turn the wrapper into a `*flat_args -> *flat_outputs` function, instead of a `*pytree_args -> *flat_outputs` function.

**Benefits:** Pytree support 🌳

**Possible Drawbacks:**

**Related GitHub Issues:** [sc-70930] [sc-71862]

Co-authored-by: Christina Lee <christina@xanadu.ai>
Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
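The `*flat_args -> *flat_outputs` wrapper idea can be sketched with `jax.tree_util` alone; `flat_fn` below is a hypothetical helper illustrating the concept, not the actual `FlatFn` class:

```python
from jax.tree_util import tree_flatten, tree_unflatten


def flat_fn(fn, example_args):
    # Hypothetical helper mirroring the FlatFn idea: wrap a
    # pytree -> pytree function as *flat_args -> *flat_outputs,
    # so differentiation machinery only ever sees flat leaves.
    _, in_tree = tree_flatten(example_args)

    def wrapper(*flat_args):
        args = tree_unflatten(in_tree, flat_args)  # rebuild pytrees
        flat_out, _ = tree_flatten(fn(*args))      # flatten outputs
        return flat_out

    return wrapper


def f(params):
    # Hypothetical function with pytree (dict) input and output.
    return {"cost": params["a"] * params["b"]}


args = ({"a": 2.0, "b": 3.0},)
wrapped = flat_fn(f, args)
flat_in, _ = tree_flatten(args)
out = wrapped(*flat_in)  # flat list with a single leaf
```

The output pytree structure would also have to be recorded inside the wrapper to unflatten results afterwards, which is the part the real `FlatFn` handles.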