Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
53df1bc
use pass name from transform
albi3ro Oct 24, 2025
cf01501
some udpates
albi3ro Nov 3, 2025
a08905b
make backwards compatible
albi3ro Nov 5, 2025
0cb9476
messed from plxpr up somehow
albi3ro Nov 13, 2025
dd3a655
Merge branch 'main' into pass-pipeline-transform-program
albi3ro Nov 13, 2025
3b2c5b8
more polishing
albi3ro Nov 13, 2025
42b5052
Merge branch 'main' into pass-pipeline-transform-program
albi3ro Nov 14, 2025
e885b91
some test fixes
albi3ro Nov 14, 2025
5c99bba
fix failing test
albi3ro Nov 17, 2025
ffedfe4
see if that fixes the failure
albi3ro Nov 17, 2025
77a211d
oops
albi3ro Nov 17, 2025
c738102
Merge branch 'main' into pass-pipeline-transform-program
albi3ro Nov 17, 2025
a1abd24
[skip ci] starting to test
albi3ro Nov 17, 2025
b6441bf
try using transform instead of passes
albi3ro Nov 18, 2025
838e45f
switch passes to being tranfsorms
albi3ro Nov 18, 2025
b2ca54d
update apply_pass and apply_pass_plugin
albi3ro Nov 18, 2025
f76ac1b
update apply_pass and apply_pass_plugin
albi3ro Nov 18, 2025
ded40ef
Merge branch 'main' into pass-pipeline-transform-program
albi3ro Nov 19, 2025
8b045d0
adding in some tests
albi3ro Nov 19, 2025
ed528ba
black and isort
albi3ro Nov 19, 2025
80c4a33
Merge branch 'main' into pass-pipeline-transform-program
albi3ro Nov 20, 2025
ec38c77
minor fixes
albi3ro Nov 20, 2025
d6477f5
Merge branch 'pass-pipeline-transform-program' into use-transform-pas…
albi3ro Nov 20, 2025
4771069
remove tests
albi3ro Nov 20, 2025
b5595af
Apply suggestion from @albi3ro
albi3ro Nov 20, 2025
579c27b
leave pipeline test in
albi3ro Nov 20, 2025
068e30b
Merge branch 'use-transform-pass-name' of https://github.com/PennyLan…
albi3ro Nov 20, 2025
e01158e
update name of cancel_inverses in test
albi3ro Nov 20, 2025
d37d89b
Try to fix the lit test again
albi3ro Nov 20, 2025
3586dd3
delete test files
albi3ro Nov 20, 2025
9d167c6
Merge branch 'main' into use-transform-pass-name
albi3ro Dec 3, 2025
80e149e
unpin pl branch
albi3ro Dec 3, 2025
3c8734e
Merge branch 'use-transform-pass-name' of https://github.com/PennyLan…
albi3ro Dec 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 38 additions & 59 deletions frontend/catalyst/passes/builtin_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
import copy
import functools
import json
from pennylane import transform

from catalyst.compiler import _options_to_cli_flags, _quantum_opt
from catalyst.passes.pass_api import PassPipelineWrapper
from catalyst.utils.exceptions import CompileError

# pylint: disable=line-too-long, too-many-lines
Expand Down Expand Up @@ -136,7 +136,7 @@ def circuit(x: float):
%2 = quantum.namedobs %out_qubits[ PauliZ] : !quantum.obs
%3 = quantum.expval %2 : f64
"""
return PassPipelineWrapper(qnode, "cancel-inverses")
return transform(pass_name="cancel-inverses")(qnode)


def disentangle_cnot(qnode):
Expand Down Expand Up @@ -225,7 +225,7 @@ def circuit():
%2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit
%out_qubits_0 = quantum.custom "PauliX"() %2 : !quantum.bit
"""
return PassPipelineWrapper(qnode, "disentangle-CNOT")
return transform(pass_name="disentangle-CNOT")(qnode)


def disentangle_swap(qnode):
Expand Down Expand Up @@ -325,7 +325,7 @@ def circuit():
%out_qubits_2:2 = quantum.custom "CNOT"() %out_qubits_1, %out_qubits : !quantum.bit, !quantum.bit
%out_qubits_3:2 = quantum.custom "CNOT"() %out_qubits_2#1, %out_qubits_2#0 : !quantum.bit, !quantum.bit
"""
return PassPipelineWrapper(qnode, "disentangle-SWAP")
return transform(pass_name="disentangle-SWAP")(qnode)


def merge_rotations(qnode):
Expand Down Expand Up @@ -391,7 +391,7 @@ def circuit(x: float):
>>> circuit(0.54)
Array(0.5965506257017892, dtype=float64)
"""
return PassPipelineWrapper(qnode, "merge-rotations")
return transform(pass_name="merge-rotations")(qnode)


def decompose_lowering(qnode):
Expand All @@ -410,7 +410,7 @@ def decompose_lowering(qnode):
// TODO: add example here

"""
return PassPipelineWrapper(qnode, "decompose-lowering") # pragma: no cover
return transform(pass_name="decompose-lowering")(qnode)


def ions_decomposition(qnode): # pragma: nocover
Expand Down Expand Up @@ -532,7 +532,7 @@ def circuit():
%out_qubits_8 = quantum.custom "RY"(%cst_2) %out_qubits_6#1 : !quantum.bit
%out_qubits_9 = quantum.custom "RY"(%cst_2) %out_qubits_7 : !quantum.bit
"""
return PassPipelineWrapper(qnode, "ions-decomposition")
return transform(pass_name="ions-decomposition")(qnode)


def to_ppr(qnode):
Expand Down Expand Up @@ -611,8 +611,7 @@ def circuit():
In the above output, ``PPR-theta-weight`` denotes the type of PPR present in the circuit, where
``theta`` is the PPR angle (:math:`\theta`) and ``weight`` is the PPR weight.
"""
return PassPipelineWrapper(qnode, "to-ppr")

return transform(pass_name="to-ppr")(qnode)

def commute_ppr(qnode=None, *, max_pauli_size=0):
R"""
Expand Down Expand Up @@ -701,8 +700,7 @@ def circuit():
if qnode is None:
return functools.partial(commute_ppr, max_pauli_size=max_pauli_size)

commute_ppr_pass = {"commute_ppr": {"max-pauli-size": max_pauli_size}}
return PassPipelineWrapper(qnode, commute_ppr_pass)
return transform(pass_name="commute-ppr")(qnode, max_pauli_size=max_pauli_size)


def merge_ppr_ppm(qnode=None, *, max_pauli_size=0):
Expand Down Expand Up @@ -782,8 +780,7 @@ def circuit():
if qnode is None:
return functools.partial(merge_ppr_ppm, max_pauli_size=max_pauli_size)

merge_ppr_ppm_pass = {"merge_ppr_ppm": {"max-pauli-size": max_pauli_size}}
return PassPipelineWrapper(qnode, merge_ppr_ppm_pass)
return transform(pass_name="merge-ppr-ppm")(qnode, max_pauli_size=max_pauli_size)


def ppr_to_ppm(qnode=None, *, decompose_method="pauli-corrected", avoid_y_measure=False):
Expand Down Expand Up @@ -882,19 +879,13 @@ def circuit():
:math:`P(\tfrac{\pi}{2}) = \exp(-iP\tfrac{\pi}{2}) = P`. Pauli operators can be commuted to the
end of the circuit and absorbed into terminal measurements.
"""
passes = {
"ppr_to_ppm": {
"decompose-method": decompose_method,
"avoid-y-measure": avoid_y_measure,
},
}

if qnode is None:
return functools.partial(
ppr_to_ppm, decompose_method=decompose_method, avoid_y_measure=avoid_y_measure
)

return PassPipelineWrapper(qnode, passes)
return transform(pass_name="ppr-to-ppm")(qnode, decompose_method=decompose_method, avoid_y_measure=avoid_y_measure)


def ppm_compilation(
Expand Down Expand Up @@ -998,13 +989,6 @@ def circuit():
``max_pauli_size`` qubits (here, ``max_pauli_size = 2``), that commutation or merge would be
skipped.
"""
passes = {
"ppm-compilation": {
"decompose-method": decompose_method,
"avoid-y-measure": avoid_y_measure,
"max-pauli-size": max_pauli_size,
}
}

if qnode is None:
return functools.partial(
Expand All @@ -1014,8 +998,7 @@ def circuit():
max_pauli_size=max_pauli_size,
)

return PassPipelineWrapper(qnode, passes)

return transform(pass_name="ppm-compilation")(qnode, decompose_method=decompose_method, avoid_y_measure=avoid_y_measure, max_pauli_size=max_pauli_size)

def ppm_specs(fn):
R"""
Expand Down Expand Up @@ -1088,34 +1071,31 @@ def loop(i):
. . .

"""

if fn.mlir_module is not None:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just switching to using a guard clause here. No actual code changes, but made it easier to scan the source code and see what was happening.

# aot mode
new_options = copy.copy(fn.compile_options)
if new_options.pipelines is None:
raise CompileError("No pipeline found")

# add ppm-spec pass at the end to existing pipeline
_, pass_list = new_options.pipelines[0] # first pipeline runs the user passes
# check if ppm-specs is already in the pass list
if "ppm-specs" not in pass_list: # pragma: nocover
pass_list.append("ppm-specs")

new_options = _options_to_cli_flags(new_options)
raw_result = _quantum_opt(*new_options, [], stdin=str(fn.mlir_module))

try:
return json.loads(
raw_result[: raw_result.index("module")]
) # remove MLIR starting with substring "module..."
except Exception as e: # pragma: nocover
raise CompileError(
"Invalid json format encountered in ppm_specs. "
f"Expected valid JSON but got {raw_result[: raw_result.index('module')]}"
) from e

else:
if fn.mlir_module is None:
raise NotImplementedError("PPM passes only support AOT (Ahead-Of-Time) compilation mode.")
# aot mode
new_options = copy.copy(fn.compile_options)
if new_options.pipelines is None:
raise CompileError("No pipeline found")

# add ppm-spec pass at the end to existing pipeline
_, pass_list = new_options.pipelines[0] # first pipeline runs the user passes
# check if ppm-specs is already in the pass list
if "ppm-specs" not in pass_list: # pragma: nocover
pass_list.append("ppm-specs")

new_options = _options_to_cli_flags(new_options)
raw_result = _quantum_opt(*new_options, [], stdin=str(fn.mlir_module))

try:
return json.loads(
raw_result[: raw_result.index("module")]
) # remove MLIR starting with substring "module..."
except Exception as e: # pragma: nocover
raise CompileError(
"Invalid json format encountered in ppm_specs. "
f"Expected valid JSON but got {raw_result[: raw_result.index('module')]}"
) from e


def reduce_t_depth(qnode):
Expand Down Expand Up @@ -1194,8 +1174,7 @@ def circuit():
%9:3 = qec.ppr ["X", "X", "Y"](8) %8#0, %8#1, %8#2:!quantum.bit, !quantum.bit, !quantum.bit
. . .
"""

return PassPipelineWrapper(qnode, "reduce-t-depth")
return transform(pass_name="reduce-t-depth")(qnode)


def ppr_to_mbqc(qnode):
Expand Down Expand Up @@ -1284,4 +1263,4 @@ def circuit():
...

"""
return PassPipelineWrapper(qnode, "ppr-to-mbqc")
return transform(pass_name="ppr-to-mbqc")(qnode)
4 changes: 2 additions & 2 deletions frontend/catalyst/passes/pass_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def module():
"""

def decorator(qnode):
return PassPipelineWrapper(qnode, pass_name, *flags, **valued_options)
return qml.transform(pass_name=pass_name)(qnode, *flags, **valued_options)

return decorator

Expand Down Expand Up @@ -244,7 +244,7 @@ def module():
raise FileNotFoundError(f"File '{path_to_plugin}' does not exist.")

def decorator(qnode):
return PassPipelineWrapper(qnode, pass_name, *flags, **valued_options)
return qml.transform(pass_name=pass_name)(qnode, *flags, **valued_options)

return decorator

Expand Down
43 changes: 20 additions & 23 deletions frontend/test/pytest/test_peephole_optimizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,29 +163,6 @@ def classical_func():
):
pipeline({})(classical_func)

with pytest.raises(
TypeError,
match="A QNode is expected, got the classical function",
):
cancel_inverses(classical_func)

with pytest.raises(
TypeError,
match="A QNode is expected, got the classical function",
):
merge_rotations(classical_func)

with pytest.raises(
TypeError,
match="A QNode is expected, got the classical function",
):
disentangle_cnot(classical_func)

with pytest.raises(
TypeError,
match="A QNode is expected, got the classical function",
):
disentangle_swap(classical_func)

test_passes_not_on_qnode()

Expand All @@ -211,6 +188,26 @@ def test_chained_apply_passes_workflow(x: float):
assert "merge-rotations" in mlir


def test_chained_transforms():
"""
Test that chained transforms are present in the transform passes.
"""

@qjit
@qml.transforms.merge_rotations
@qml.transforms.cancel_inverses
@qml.qnode(qml.device("lightning.qubit", wires=2))
def test_chained_apply_passes_workflow(x: float):
qml.Hadamard(wires=[1])
qml.RX(x, wires=[0])
qml.RX(-x, wires=[0])
qml.Hadamard(wires=[1])
return qml.expval(qml.PauliY(wires=0))

assert "cancel-inverses" in test_chained_apply_passes_workflow.mlir
assert "merge-rotations" in test_chained_apply_passes_workflow.mlir


def test_disentangle_passes():
"""
Test that disentangle passes are present in the transform passes
Expand Down
Loading