Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Program Capture] higher order primitives support pytree input and output #6081

Merged
merged 21 commits into from
Aug 13, 2024

Conversation

albi3ro
Copy link
Contributor

@albi3ro albi3ro commented Aug 7, 2024

Context:

Similar to jax, we should be able to catch higher order primitives that accept arbitrary pytree inputs and outputs. Currently, our primitives assume flat inputs and flat outputs.

Description of the Change:

Updates qnode_call, for_loop, and cond to allow arbitrary pytree inputs and outputs.

For example,

@qml.for_loop(1, 7, 2)
def f(i, x):
    return {"x": i+x["x"]}

x = {"x": 0}
f(x) 
{'x': Array(9, dtype=int32, weak_type=True)}

Benefits:

Easier to cart complicated combinations of variables around. More versatile program capture.

Possible Drawbacks:

If different branches (ie cond) have different pytree structures, we may run into weird behavior.

Related GitHub Issues:

doc/releases/changelog-dev.md Outdated Show resolved Hide resolved
Copy link

codecov bot commented Aug 7, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 99.65%. Comparing base (5a7a163) to head (dd8f722).
Report is 1 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #6081      +/-   ##
==========================================
- Coverage   99.66%   99.65%   -0.01%     
==========================================
  Files         430      431       +1     
  Lines       41857    41593     -264     
==========================================
- Hits        41716    41451     -265     
- Misses        141      142       +1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Contributor

@mudit2812 mudit2812 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good. Now that the while loop capture is merged, maybe worth adding pytree capture support for that as well?

pennylane/capture/explanations.md Outdated Show resolved Hide resolved
pennylane/capture/explanations.md Outdated Show resolved Hide resolved
pennylane/capture/explanations.md Show resolved Hide resolved
pennylane/compiler/qjit_api.py Show resolved Hide resolved
pennylane/ops/op_math/condition.py Outdated Show resolved Hide resolved
tests/capture/test_capture_cond.py Outdated Show resolved Hide resolved
@albi3ro albi3ro requested a review from mudit2812 August 9, 2024 13:34
pennylane/capture/explanations.md Outdated Show resolved Hide resolved
pennylane/capture/flatfn.py Outdated Show resolved Hide resolved
pennylane/compiler/qjit_api.py Show resolved Hide resolved
Co-authored-by: Utkarsh <utkarshazad98@gmail.com>
Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
pennylane/capture/explanations.md Outdated Show resolved Hide resolved
pennylane/capture/explanations.md Outdated Show resolved Hide resolved
pennylane/capture/explanations.md Outdated Show resolved Hide resolved
Co-authored-by: Thomas R. Bromley <49409390+trbromley@users.noreply.github.com>
Copy link
Contributor

@mudit2812 mudit2812 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not going to block approval, but could you also add a test for qml.cond with MCM predicates that accepts pytree arguments? Conditionals with MCMs cannot return anything, so the only thing that would need to be tested is pytree inputs.

Copy link
Contributor

@obliviateandsurrender obliviateandsurrender left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, just one small non-blocking doubt!

doc/releases/changelog-dev.md Outdated Show resolved Hide resolved
pennylane/capture/explanations.md Outdated Show resolved Hide resolved
pennylane/compiler/qjit_api.py Show resolved Hide resolved
@albi3ro
Copy link
Contributor Author

albi3ro commented Aug 12, 2024

I'm not going to block approval, but could you also add a test for qml.cond with MCM predicates that accepts pytree arguments? Conditionals with MCMs cannot return anything, so the only thing that would need to be tested is pytree inputs.

Added.

albi3ro and others added 2 commits August 12, 2024 16:59
@albi3ro albi3ro enabled auto-merge (squash) August 13, 2024 12:54
@albi3ro albi3ro merged commit 6f77dce into master Aug 13, 2024
40 checks passed
@albi3ro albi3ro deleted the capture-pytrees branch August 13, 2024 13:30
dwierichs added a commit that referenced this pull request Sep 15, 2024
…jacobian` (#6134)

**Context:**
#6120 and #6127 add support to capture `qml.grad` and `qml.jacobian` in
plxpr. Once captured, they dispatch to `jax.grad` and `jax.jacobian`.

**Description of the Change:**
This PR adds support for pytree inputs and outputs of the differentiated
functions, similar to #6081.
For this, it extends the internal class `FlatFn` by the extra
functionality to turn the wrapper into a `*flat_args -> *flat_outputs`
function, instead of a `*pytree_args -> *flat_outputs` function.

**Benefits:**
Pytree support 🌳 

**Possible Drawbacks:**

**Related GitHub Issues:**

[sc-70930]
[sc-71862]

---------

Co-authored-by: Christina Lee <christina@xanadu.ai>
Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
mudit2812 added a commit that referenced this pull request Sep 16, 2024
…jacobian` (#6134)

**Context:**
#6120 and #6127 add support to capture `qml.grad` and `qml.jacobian` in
plxpr. Once captured, they dispatch to `jax.grad` and `jax.jacobian`.

**Description of the Change:**
This PR adds support for pytree inputs and outputs of the differentiated
functions, similar to #6081.
For this, it extends the internal class `FlatFn` by the extra
functionality to turn the wrapper into a `*flat_args -> *flat_outputs`
function, instead of a `*pytree_args -> *flat_outputs` function.

**Benefits:**
Pytree support 🌳 

**Possible Drawbacks:**

**Related GitHub Issues:**

[sc-70930]
[sc-71862]

---------

Co-authored-by: Christina Lee <christina@xanadu.ai>
Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
mudit2812 added a commit that referenced this pull request Sep 16, 2024
…jacobian` (#6134)

**Context:**
#6120 and #6127 add support to capture `qml.grad` and `qml.jacobian` in
plxpr. Once captured, they dispatch to `jax.grad` and `jax.jacobian`.

**Description of the Change:**
This PR adds support for pytree inputs and outputs of the differentiated
functions, similar to #6081.
For this, it extends the internal class `FlatFn` by the extra
functionality to turn the wrapper into a `*flat_args -> *flat_outputs`
function, instead of a `*pytree_args -> *flat_outputs` function.

**Benefits:**
Pytree support 🌳 

**Possible Drawbacks:**

**Related GitHub Issues:**

[sc-70930]
[sc-71862]

---------

Co-authored-by: Christina Lee <christina@xanadu.ai>
Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
mudit2812 added a commit that referenced this pull request Sep 18, 2024
…jacobian` (#6134)

**Context:**
#6120 and #6127 add support to capture `qml.grad` and `qml.jacobian` in
plxpr. Once captured, they dispatch to `jax.grad` and `jax.jacobian`.

**Description of the Change:**
This PR adds support for pytree inputs and outputs of the differentiated
functions, similar to #6081.
For this, it extends the internal class `FlatFn` by the extra
functionality to turn the wrapper into a `*flat_args -> *flat_outputs`
function, instead of a `*pytree_args -> *flat_outputs` function.

**Benefits:**
Pytree support 🌳 

**Possible Drawbacks:**

**Related GitHub Issues:**

[sc-70930]
[sc-71862]

---------

Co-authored-by: Christina Lee <christina@xanadu.ai>
Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
mudit2812 added a commit that referenced this pull request Sep 23, 2024
…jacobian` (#6134)

**Context:**
#6120 and #6127 add support to capture `qml.grad` and `qml.jacobian` in
plxpr. Once captured, they dispatch to `jax.grad` and `jax.jacobian`.

**Description of the Change:**
This PR adds support for pytree inputs and outputs of the differentiated
functions, similar to #6081.
For this, it extends the internal class `FlatFn` by the extra
functionality to turn the wrapper into a `*flat_args -> *flat_outputs`
function, instead of a `*pytree_args -> *flat_outputs` function.

**Benefits:**
Pytree support 🌳 

**Possible Drawbacks:**

**Related GitHub Issues:**

[sc-70930]
[sc-71862]

---------

Co-authored-by: Christina Lee <christina@xanadu.ai>
Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants