Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add merge rotation pass #1162

Merged
merged 60 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
0d85eaf
Add pattern
rmoyard Sep 30, 2024
248ab5e
Update structure
rmoyard Sep 30, 2024
9ccb01a
Update
rmoyard Oct 3, 2024
3e5d1d1
Working draft
rmoyard Oct 3, 2024
c4f46ba
Update
rmoyard Oct 3, 2024
c9d56b6
Merge branch 'main' into merge_rotations
rmoyard Oct 7, 2024
7286abc
renamed to `ChainedNamedHermitianOpRewritePattern`
paul0403 Oct 7, 2024
e1f54e5
Add test
rmoyard Oct 7, 2024
4f3f9e2
Merge branch 'merge_rotations' of https://github.com/PennyLaneAI/cata…
rmoyard Oct 7, 2024
043a5f0
MLIR test: CRY switch qubits
rmoyard Oct 7, 2024
1f87943
add pattern
paul0403 Oct 7, 2024
1458937
preprocess with cse pass so we can check param SSA values;
paul0403 Oct 7, 2024
b8b9913
tests
paul0403 Oct 7, 2024
43fcce7
format
paul0403 Oct 7, 2024
799d96e
test with explicit rotation angles
paul0403 Oct 7, 2024
b60519e
test with different explicit params
paul0403 Oct 7, 2024
84c758c
cano test
rmoyard Oct 7, 2024
e3447ef
Update
rmoyard Oct 7, 2024
a43d9ac
changelog
paul0403 Oct 7, 2024
05f5a5b
Initial draft multiRZ
rmoyard Oct 7, 2024
bc5faaf
Typo
rmoyard Oct 7, 2024
485c6d6
ctrl gates
paul0403 Oct 8, 2024
9dcfdc6
Merge remote-tracking branch 'origin/main' into cancel_inverse_adjoint
paul0403 Oct 8, 2024
90ddcf1
remove template type in parent getter (a value will have just one def…
paul0403 Oct 8, 2024
b9acfe6
factor out a parent gate verifier analysis, so it can be reused with …
paul0403 Oct 8, 2024
13eb094
add all test cases for ctrl
paul0403 Oct 9, 2024
91beb68
Merge remote-tracking branch 'origin/main' into cancel_inverse_adjoint
paul0403 Oct 9, 2024
ebb5169
make the named hermitian pattern use the common analysis as well
paul0403 Oct 9, 2024
5c77fb3
one more test
paul0403 Oct 9, 2024
0823116
Merge remote-tracking branch 'origin/main' into cancel_inverse_adjoint
paul0403 Oct 9, 2024
9ae9ba7
follow include order guideline
paul0403 Oct 9, 2024
738c96a
`verified` --> `succeeded`
paul0403 Oct 9, 2024
e78cca1
move namecheck before wire verification
paul0403 Oct 10, 2024
93b1ed1
Merge remote-tracking branch 'origin/cancel_inverse_adjoint' into mer…
rmoyard Oct 10, 2024
6324fd0
Add analysis integration
rmoyard Oct 10, 2024
a6b8424
MultiRz case
rmoyard Oct 10, 2024
b24703a
Split verifier into a "normal" one and an aggressive one.
paul0403 Oct 10, 2024
2f95838
use aggressive for named gates
paul0403 Oct 10, 2024
d092540
Merge remote-tracking branch 'origin/main' into cancel_inverse_adjoint
paul0403 Oct 10, 2024
9bc658b
add multirz
paul0403 Oct 10, 2024
1fe2622
Merge remote-tracking branch 'origin/main' into cancel_inverse_adjoint
paul0403 Oct 10, 2024
552fae2
Merge remote-tracking branch 'origin/main' into cancel_inverse_adjoint
paul0403 Oct 11, 2024
9a7800a
Merge branch 'cancel_inverse_adjoint' into merge_rotations
rmoyard Oct 11, 2024
0660b68
changelog
paul0403 Oct 11, 2024
a7cb5af
change aggressive name to VerifyParentGateAndNameAnalysis
paul0403 Oct 11, 2024
b54fcb7
Merge branch 'cancel_inverse_adjoint' into merge_rotations
rmoyard Oct 11, 2024
02dc927
changelog grammar
paul0403 Oct 11, 2024
7ba892f
Add multirz test
rmoyard Oct 11, 2024
4fb926f
Merge branch 'cancel_inverse_adjoint' into merge_rotations
rmoyard Oct 11, 2024
4ae4281
Update doc
rmoyard Oct 11, 2024
9cee8c9
Merge branch 'merge_rotations' of https://github.com/PennyLaneAI/cata…
rmoyard Oct 11, 2024
1b970da
Update mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp
rmoyard Oct 11, 2024
ca6df1b
Update
rmoyard Oct 11, 2024
f394838
Merge branch 'merge_rotations' of https://github.com/PennyLaneAI/cata…
rmoyard Oct 11, 2024
090d285
Merge branch 'main' into merge_rotations
rmoyard Oct 11, 2024
d0789f0
Merge branch 'merge_rotations' of https://github.com/PennyLaneAI/cata…
rmoyard Oct 11, 2024
bd0ddb6
Update
rmoyard Oct 11, 2024
01cd3f5
Remove erase
rmoyard Oct 11, 2024
def5b3a
Update
rmoyard Oct 11, 2024
0213ae7
Pylint
rmoyard Oct 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,43 @@
Available MLIR passes are now documented and available within the
[catalyst.passes module documentation](https://docs.pennylane.ai/projects/catalyst/en/stable/code/__init__.html#module-catalyst.passes).

* A peephole merge rotations pass is now available in MLIR. It can be added to `catalyst.passes.pipeline`, or the
Python function `catalyst.passes.merge_rotations` can be directly called on a `QNode`.
[(#1162)](https://github.com/PennyLaneAI/catalyst/pull/1162)

Using the pipeline, one can run:

```python
from catalys.passes import pipeline

my_passes = {
"merge_rotations": {}
}

@qjit(circuit_transform_pipeline=my_passes)
@qml.qnode(qml.device("lightning.qubit", wires=1))
def g(x: float):
qml.RX(x, wires=0)
qml.RX(x, wires=0)
qml.Hadamard(wires=0)
return qml.expval(qml.PauliZ(0))
```

Using the python function, one can run:

```python
from catalys.passes import merge_rotations

@qjit
@merge_rotations
@qml.qnode(qml.device("lightning.qubit", wires=1))
def g(x: float):
qml.RX(x, wires=0)
qml.RX(x, wires=0)
qml.Hadamard(wires=0)
return qml.expval(qml.PauliZ(0))
```

* Catalyst Autograph now supports updating a single index or a slice of JAX arrays using Python's
array assignment operator syntax.
[(#769)](https://github.com/PennyLaneAI/catalyst/pull/769)
Expand Down
83 changes: 82 additions & 1 deletion frontend/catalyst/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,87 @@ def wrapper(*args, **kwrags):
return fn_clone


def merge_rotations(fn=None):
"""
Specify that the ``-merge-rotations`` MLIR compiler pass
for merging roations (peephole) will be applied.

The full list of supported gates are as follows:

:class:`qml.RX <pennylane.RX>`,
:class:`qml.CRX <pennylane.CRX>`,
:class:`qml.RY <pennylane.RY>`,
:class:`qml.CRY <pennylane.CRY>`,
:class:`qml.RZ <pennylane.RZ>`,
:class:`qml.CRZ <pennylane.CRZ>`,
:class:`qml.PhaseShift <pennylane.PhaseShift>`,
:class:`qml.ControlledPhaseShift <pennylane.ControlledPhaseShift>`,
:class:`qml.Rot <pennylane.Rot>`,
:class:`qml.CRot <pennylane.CRot>`,
:class:`qml.MultiRZ <pennylane.MultiRZ>`.


.. note::

Unlike PennyLane :doc:`circuit transformations <introduction/compiling_circuits>`,
the QNode itself will not be changed or transformed by applying these
decorators.

As a result, circuit inspection tools such as :func:`~.draw` will continue
to display the circuit as written in Python.

Args:
fn (QNode): the QNode to apply the cancel inverses compiler pass to

Returns:
~.QNode:

**Example**

In this example the three :class:`qml.RX <pennylane.RX>` will be merged in a single
one with the sum of angles as parameter.

.. code-block:: python

from catalyst.debug import get_compilation_stage
from catalyst.passes import merge_rotations

dev = qml.device("lightning.qubit", wires=1)

@qjit(keep_intermediate=True)
@merge_rotations
@qml.qnode(dev)
def circuit(x: float):
qml.RX(x, wires=0)
qml.RX(0.1, wires=0)
qml.RX(x**2, wires=0)
return qml.expval(qml.PauliZ(0))

>>> circuit(0.54)
Array(0.5965506257017892, dtype=float64)
"""
if not isinstance(fn, qml.QNode):
raise TypeError(f"A QNode is expected, got the classical function {fn}")

funcname = fn.__name__
wrapped_qnode_function = fn.func
uniquer = str(_rename_to_unique())

def wrapper(*args, **kwrags):
if EvaluationContext.is_tracing():
apply_registered_pass_p.bind(
pass_name="merge-rotations",
options=f"func-name={funcname}" + "_merge_rotations" + uniquer,
)
return wrapped_qnode_function(*args, **kwrags)

fn_clone = copy.copy(fn)
fn_clone.func = wrapper
fn_clone.__name__ = funcname + "_merge_rotations" + uniquer

return fn_clone


## IMPL and helpers ##
# pylint: disable=missing-function-docstring
class _PipelineNameUniquer:
Expand All @@ -332,7 +413,7 @@ def _rename_to_unique():


def _API_name_to_pass_name():
return {"cancel_inverses": "remove-chained-self-inverse", "merge_rotations": "merge-rotation"}
return {"cancel_inverses": "remove-chained-self-inverse", "merge_rotations": "merge-rotations"}


def _inject_transform_named_sequence():
Expand Down
97 changes: 82 additions & 15 deletions frontend/test/lit/test_peephole_optimizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

from catalyst import qjit
from catalyst.debug import get_compilation_stage
from catalyst.passes import cancel_inverses, pipeline
from catalyst.passes import cancel_inverses, merge_rotations, pipeline


def flush_peephole_opted_mlir_to_iostream(QJIT):
Expand Down Expand Up @@ -86,7 +86,7 @@ def test_pipeline_lowering():
"""
my_pipeline = {
"cancel_inverses": {},
"merge_rotations": {"my-option": "aloha"},
"merge_rotations": {},
}

@qjit(keep_intermediate=True)
Expand All @@ -104,14 +104,14 @@ def test_pipeline_lowering_workflow(x):
# CHECK: pass_name=remove-chained-self-inverse
# CHECK: ]
# CHECK: _:AbstractTransformMod() = apply_registered_pass[
# CHECK: options=func-name=test_pipeline_lowering_workflow_transformed0 my-option=aloha
# CHECK: pass_name=merge-rotation
# CHECK: options=func-name=test_pipeline_lowering_workflow_transformed0
# CHECK: pass_name=merge-rotations
# CHECK: ]
print_jaxpr(test_pipeline_lowering_workflow, 1.2)

# CHECK: transform.named_sequence @__transform_main
# CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=test_pipeline_lowering_workflow_transformed0"}
# CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotation" to {{%.+}} {options = "func-name=test_pipeline_lowering_workflow_transformed0 my-option=aloha"}
# CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} {options = "func-name=test_pipeline_lowering_workflow_transformed0"}
# CHECK-NEXT: transform.yield
print_mlir(test_pipeline_lowering_workflow, 1.2)

Expand Down Expand Up @@ -160,13 +160,13 @@ def test_pipeline_lowering_keep_original_workflow(x):
# CHECK: ]
# CHECK: _:AbstractTransformMod() = apply_registered_pass[
# CHECK: options=func-name=f_transformed0
# CHECK: pass_name=merge-rotation
# CHECK: pass_name=merge-rotations
# CHECK: ]
print_jaxpr(test_pipeline_lowering_keep_original_workflow, 1.2)

# CHECK: transform.named_sequence @__transform_main
# CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=f_transformed0"}
# CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotation" to {{%.+}} {options = "func-name=f_transformed0"}
# CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} {options = "func-name=f_transformed0"}
# CHECK-NEXT: transform.yield
print_mlir(test_pipeline_lowering_keep_original_workflow, 1.2)

Expand Down Expand Up @@ -223,23 +223,23 @@ def h(x):
# CHECK: ]
# CHECK: _:AbstractTransformMod() = apply_registered_pass[
# CHECK: options=func-name=g_transformed0
# CHECK: pass_name=merge-rotation
# CHECK: pass_name=merge-rotations
# CHECK: ]
# CHECK: _:AbstractTransformMod() = apply_registered_pass[
# CHECK: options=func-name=h_transformed1
# CHECK: pass_name=remove-chained-self-inverse
# CHECK: ]
# CHECK: _:AbstractTransformMod() = apply_registered_pass[
# CHECK: options=func-name=h_transformed1
# CHECK: pass_name=merge-rotation
# CHECK: pass_name=merge-rotations
# CHECK: ]
print_jaxpr(global_wf)

# CHECK: transform.named_sequence @__transform_main
# CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=g_transformed0"}
# CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotation" to {{%.+}} {options = "func-name=g_transformed0"}
# CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} {options = "func-name=g_transformed0"}
# CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=h_transformed1"}
# CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotation" to {{%.+}} {options = "func-name=h_transformed1"}
# CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} {options = "func-name=h_transformed1"}
# CHECK-NEXT: transform.yield
print_mlir(global_wf)

Expand Down Expand Up @@ -301,20 +301,20 @@ def h(x):
# CHECK: ]
# CHECK: _:AbstractTransformMod() = apply_registered_pass[
# CHECK: options=func-name=g_transformed1
# CHECK: pass_name=merge-rotation
# CHECK: pass_name=merge-rotations
# CHECK: ]
# CHECK: _:AbstractTransformMod() = apply_registered_pass[
# CHECK: options=func-name=h_transformed0
# CHECK-NOT: pass_name=remove-chained-self-inverse
# CHECK: pass_name=merge-rotation
# CHECK: pass_name=merge-rotations
# CHECK: ]
print_jaxpr(global_wf)

# CHECK: transform.named_sequence @__transform_main
# CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=g_transformed1"}
# CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotation" to {{%.+}} {options = "func-name=g_transformed1"}
# CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} {options = "func-name=g_transformed1"}
# CHECK-NOT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=h_transformed0"}
# CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotation" to {{%.+}} {options = "func-name=h_transformed0"}
# CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} {options = "func-name=h_transformed0"}
# CHECK-NEXT: transform.yield
print_mlir(global_wf)

Expand Down Expand Up @@ -563,3 +563,70 @@ def test_cancel_inverses_keep_original_workflow2():


test_cancel_inverses_keep_original()


#
# merge_rotations
#


def test_merge_rotations_tracing_and_lowering():
"""
Test merge_rotations during tracing and lowering
"""

@qjit
def test_merge_rotations_tracing_and_lowering_workflow(xx: float):

@merge_rotations
@qml.qnode(qml.device("lightning.qubit", wires=1))
def f(x: float):
qml.RX(x, wires=0)
qml.RX(x, wires=0)
qml.Hadamard(wires=0)
return qml.expval(qml.PauliZ(0))

@merge_rotations
@qml.qnode(qml.device("lightning.qubit", wires=1))
def g(x: float):
qml.RX(x, wires=0)
qml.RX(x, wires=0)
qml.Hadamard(wires=0)
return qml.expval(qml.PauliZ(0))

@qml.qnode(qml.device("lightning.qubit", wires=1))
def h(x: float):
qml.RX(x, wires=0)
qml.RX(x, wires=0)
qml.Hadamard(wires=0)
return qml.expval(qml.PauliZ(0))

_f = f(xx)
_g = g(xx)
_h = h(xx)
return _f, _g, _h

# CHECK: transform_named_sequence
# CHECK: _:AbstractTransformMod() = apply_registered_pass[
# CHECK: options=func-name=f_merge_rotations0
# CHECK: pass_name=merge-rotations
# CHECK: ]
# CHECK: _:AbstractTransformMod() = apply_registered_pass[
# CHECK: options=func-name=g_merge_rotations1
# CHECK: pass_name=merge-rotations
# CHECK: ]
# CHECK-NOT: _:AbstractTransformMod() = apply_registered_pass[
# CHECK-NOT: options=func-name=h_merge_rotations
# CHECK-NOT: pass_name=merge-rotations
print_jaxpr(test_merge_rotations_tracing_and_lowering_workflow, 1.1)

# CHECK: module @test_merge_rotations_tracing_and_lowering_workflow
# CHECK: transform.named_sequence @__transform_main
# CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} {options = "func-name=f_merge_rotations0"}
# CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} {options = "func-name=g_merge_rotations1"}
# CHECK-NOT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}} {options = "func-name=h_merge_rotations"}
# CHECK-NEXT: transform.yield
print_mlir(test_merge_rotations_tracing_and_lowering_workflow, 1.1)


test_merge_rotations_tracing_and_lowering()
Loading
Loading