
[Program Capture] Capture & execute qml.grad in plxpr #6120

Merged
merged 42 commits into master from capture-grad on Sep 9, 2024
Conversation

Contributor

@dwierichs dwierichs commented Aug 21, 2024

Context:
The new `qml.capture` module does not support differentiation yet.

Description of the Change:
This PR takes the first step towards differentiability in plxpr.
It adds the capability to capture `qml.grad` as a "nested jaxpr" primitive.
When the captured program is executed, `qml.grad` is effectively replaced by `jax.grad`, because running Autograd's automatic differentiation inside the Jaxpr ecosystem is not sensible.
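The intended dispatch can be sketched generically; `evaluate_grad_primitive` below is an illustrative stand-in, not PennyLane's actual evaluation rule:

```python
import jax
import jax.numpy as jnp

def evaluate_grad_primitive(inner_fn, argnum, *args):
    """Stand-in for the grad primitive's evaluation rule: instead of
    running Autograd, differentiate the captured function with jax.grad."""
    return jax.grad(inner_fn, argnums=argnum)(*args)

def circuit_like(x):
    # scalar-valued stand-in for a QNode
    return jnp.sin(x) ** 2

g = evaluate_grad_primitive(circuit_like, 0, 0.5)
# d/dx sin(x)^2 = sin(2x), so g ≈ sin(1.0)
```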

Benefits:
Captures the first differentiation instructions in plxpr.

Possible Drawbacks:
The current implementation constructs a `jvp` for every evaluation of a QNode gradient. This means that the JVP function is reconstructed on every evaluation call (if I'm not mistaken), making the code significantly less performant with `capture` than without. Of course, the longer-term plan is to process the plxpr into lower-level code by lowering the `grad` primitive itself, in which case this problem goes away.
A similar redundancy exists in `QNode`: whenever a `qnode` primitive is evaluated, a new `QNode` is created (and only ever evaluated once). This disables caching, for example, unless a cache is passed around explicitly.
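The caching drawback is not PennyLane-specific; it can be reproduced with plain `jax.jit`, which caches compiled traces per function object, so creating a fresh function on every call defeats the cache (a minimal sketch):

```python
import jax
import jax.numpy as jnp

trace_count = 0

def make_fn():
    # simulates creating a new function object (e.g. a new QNode) per call
    def f(x):
        global trace_count
        trace_count += 1  # runs once per trace, not once per call
        return jnp.sin(x)
    return f

# reusing one jitted function: traced once, cached afterwards
jitted = jax.jit(make_fn())
jitted(0.1)
jitted(0.2)
cached_traces = trace_count  # 1

# a fresh function object per call: re-traced every time
jax.jit(make_fn())(0.1)
jax.jit(make_fn())(0.2)
fresh_traces = trace_count - cached_traces  # 2
```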

Related GitHub Issues:

[sc-71858]


Hello. You may have forgotten to update the changelog!
Please edit doc/releases/changelog-dev.md with:

  • A one-to-two sentence description of the change. You may include a small working example for new features.
  • A link back to this PR.
  • Your name (or GitHub username) in the contributors section.


codecov bot commented Aug 21, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 99.60%. Comparing base (df63953) to head (db99812).
Report is 1 commit behind head on master.

Additional details and impacted files
@@           Coverage Diff           @@
##           master    #6120   +/-   ##
=======================================
  Coverage   99.60%   99.60%           
=======================================
  Files         445      446    +1     
  Lines       42364    42423   +59     
=======================================
+ Hits        42198    42257   +59     
  Misses        166      166           


@dwierichs dwierichs changed the title Capture qml.grad in plxpr [Program Capture] Capture qml.grad in plxpr Aug 21, 2024
@dwierichs dwierichs changed the title [Program Capture] Capture qml.grad in plxpr [Program Capture] Capture & execute qml.grad in plxpr Aug 22, 2024
@dwierichs dwierichs added the review-ready 👌 PRs which are ready for review by someone from the core team. label Aug 22, 2024
@mudit2812 mudit2812 self-requested a review August 22, 2024 14:22
@albi3ro albi3ro self-requested a review August 22, 2024 14:37
Contributor Author

@dwierichs dwierichs left a comment
Just leaving some (hopefully helpful) comments.

Review threads: pennylane/_grad.py, pennylane/capture/capture_qnode.py, pennylane/capture/primitives.py
Contributor

@mudit2812 mudit2812 left a comment

🎉

@dwierichs dwierichs added this pull request to the merge queue Sep 9, 2024
github-merge-queue bot pushed a commit that referenced this pull request Sep 9, 2024
[Program Capture] Capture & execute `qml.grad` in plxpr
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to no response for status checks Sep 9, 2024
@dwierichs dwierichs added this pull request to the merge queue Sep 9, 2024
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to no response for status checks Sep 9, 2024
@dwierichs dwierichs merged commit 1d220cf into master Sep 9, 2024
37 checks passed
@dwierichs dwierichs deleted the capture-grad branch September 9, 2024 13:06
dwierichs added a commit that referenced this pull request Sep 9, 2024
**Context:**
We're adding support for differentiation in plxpr, also see #6120.

**Description of the Change:**
This PR adds support for `qml.jacobian`, similar to the support for
`qml.grad`.
Note that Pytree support will be needed to allow for multi-argument
derivatives.
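A minimal sketch of the target behaviour, using an illustrative vector-valued stand-in for a captured workload: once captured, `qml.jacobian` dispatches to `jax.jacobian` for non-scalar outputs.

```python
import jax
import jax.numpy as jnp

def vector_valued(x):
    # non-scalar output, so a plain gradient would not apply
    return jnp.stack([jnp.sin(x), jnp.cos(x), x ** 2])

jac = jax.jacobian(vector_valued)(0.3)
# jac has shape (3,): [cos(x), -sin(x), 2x] at x = 0.3
```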

**Benefits:**
Capture derivatives of non-scalar functions.

**Possible Drawbacks:**
See discussion around `qml.grad` in #6120.

**Related GitHub Issues:**

[sc-71860]

---------

Co-authored-by: Christina Lee <christina@xanadu.ai>
mudit2812 pushed a commit that referenced this pull request Sep 10, 2024
mudit2812 pushed a commit that referenced this pull request Sep 10, 2024
mudit2812 pushed a commit that referenced this pull request Sep 12, 2024
mudit2812 pushed a commit that referenced this pull request Sep 12, 2024
dwierichs added a commit that referenced this pull request Sep 15, 2024
…jacobian` (#6134)

**Context:**
#6120 and #6127 add support to capture `qml.grad` and `qml.jacobian` in
plxpr. Once captured, they dispatch to `jax.grad` and `jax.jacobian`.

**Description of the Change:**
This PR adds support for pytree inputs and outputs of the differentiated
functions, similar to #6081.
For this, it extends the internal class `FlatFn` by the extra
functionality to turn the wrapper into a `*flat_args -> *flat_outputs`
function, instead of a `*pytree_args -> *flat_outputs` function.
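The idea can be sketched with plain `jax.tree_util` (`flatten_fn` below is an illustrative reimplementation, not PennyLane's internal `FlatFn`):

```python
import jax
import jax.numpy as jnp

def flatten_fn(fn, example_args):
    """Wrap a pytree-accepting function so it maps *flat_args -> *flat_outputs."""
    _, in_tree = jax.tree_util.tree_flatten(example_args)

    def flat_fn(*flat_args):
        # rebuild the pytree arguments from flat leaves, call, re-flatten outputs
        args = jax.tree_util.tree_unflatten(in_tree, flat_args)
        out = fn(*args)
        flat_out, _ = jax.tree_util.tree_flatten(out)
        return flat_out

    return flat_fn

def f(params):
    return {"cos": jnp.cos(params["x"]), "sin": jnp.sin(params["x"])}

flat_f = flatten_fn(f, ({"x": 0.1},))
leaves = flat_f(jnp.array(0.1))  # flat inputs in, flat outputs out
```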

**Benefits:**
Pytree support 🌳 

**Possible Drawbacks:**

**Related GitHub Issues:**

[sc-70930]
[sc-71862]

---------

Co-authored-by: Christina Lee <christina@xanadu.ai>
Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
mudit2812 added a commit that referenced this pull request Sep 16, 2024
mudit2812 pushed a commit that referenced this pull request Sep 16, 2024
mudit2812 added a commit that referenced this pull request Sep 16, 2024
mudit2812 added a commit that referenced this pull request Sep 18, 2024
mudit2812 added a commit that referenced this pull request Sep 23, 2024
Labels
review-ready 👌 PRs which are ready for review by someone from the core team.
4 participants