Merged
28 commits
53df1bc
use pass name from transform
albi3ro Oct 24, 2025
cf01501
some updates
albi3ro Nov 3, 2025
a08905b
make backwards compatible
albi3ro Nov 5, 2025
0cb9476
messed from plxpr up somehow
albi3ro Nov 13, 2025
dd3a655
Merge branch 'main' into pass-pipeline-transform-program
albi3ro Nov 13, 2025
3b2c5b8
more polishing
albi3ro Nov 13, 2025
42b5052
Merge branch 'main' into pass-pipeline-transform-program
albi3ro Nov 14, 2025
e885b91
some test fixes
albi3ro Nov 14, 2025
5c99bba
fix failing test
albi3ro Nov 17, 2025
ffedfe4
see if that fixes the failure
albi3ro Nov 17, 2025
77a211d
oops
albi3ro Nov 17, 2025
c738102
Merge branch 'main' into pass-pipeline-transform-program
albi3ro Nov 17, 2025
a1abd24
[skip ci] starting to test
albi3ro Nov 17, 2025
ded40ef
Merge branch 'main' into pass-pipeline-transform-program
albi3ro Nov 19, 2025
8b045d0
adding in some tests
albi3ro Nov 19, 2025
ed528ba
black and isort
albi3ro Nov 19, 2025
80c4a33
Merge branch 'main' into pass-pipeline-transform-program
albi3ro Nov 20, 2025
ec38c77
minor fixes
albi3ro Nov 20, 2025
1cd4f18
fix test
albi3ro Nov 20, 2025
f197df9
try and fix this lit test yet again
albi3ro Nov 20, 2025
4a46934
Update frontend/test/pytest/test_transform_pass_name.py
albi3ro Nov 20, 2025
ce8478b
Update frontend/test/pytest/test_transform_pass_name.py
albi3ro Nov 20, 2025
e7db76a
Apply suggestion from @albi3ro
albi3ro Nov 20, 2025
9096ada
Merge branch 'main' into pass-pipeline-transform-program
albi3ro Nov 20, 2025
da0b9ce
Merge branch 'main' into pass-pipeline-transform-program
albi3ro Dec 2, 2025
b6f3c8c
update version, remove unnecessary test
albi3ro Dec 3, 2025
a253ec7
Merge branch 'main' into pass-pipeline-transform-program
albi3ro Dec 3, 2025
ce51804
fix test failure
albi3ro Dec 3, 2025
2 changes: 1 addition & 1 deletion .dep-versions
@@ -10,7 +10,7 @@ enzyme=v0.0.203

# For a custom PL version, update the package version here and at
# 'doc/requirements.txt'
pennylane=0.44.0-dev42
pennylane=0.44.0-dev44

# For a custom LQ/LK version, update the package version here and at
# 'doc/requirements.txt'
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
@@ -70,6 +70,10 @@

<h3>Improvements 🛠</h3>

* Catalyst can now use the new `pass_name` property of PennyLane transform objects. Passes can
be created with `qml.transform(pass_name=pass_name)` instead of `PassPipelineWrapper`.
[(#2149)](https://github.com/PennyLaneAI/catalyst/pull/2149)

* An error is now raised if a transform is applied inside a QNode when program capture is enabled.
[(#2256)](https://github.com/PennyLaneAI/catalyst/pull/2256)

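For reference, a minimal sketch of the workflow this changelog entry enables, modeled on the new pytest added later in this PR; the device choice and circuit are illustrative assumptions, and the updated PennyLane dev version is required:

```python
import pennylane as qml
from catalyst import qjit

# A compiler pass referenced purely by its MLIR pass name (no tape implementation needed).
cancel_inverses_pass = qml.transform(pass_name="cancel-inverses")

@qjit(target="mlir")
@cancel_inverses_pass
@qml.qnode(qml.device("lightning.qubit", wires=1))
def circuit():
    qml.Hadamard(wires=0)
    qml.Hadamard(wires=0)
    return qml.expval(qml.PauliZ(0))

# The scheduled pass should appear in the lowered module as
# transform.apply_registered_pass "cancel-inverses".
print(circuit.mlir)
```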
2 changes: 1 addition & 1 deletion doc/requirements.txt
@@ -33,4 +33,4 @@ lxml_html_clean
--extra-index-url https://test.pypi.org/simple/
pennylane-lightning-kokkos==0.44.0-dev16
pennylane-lightning==0.44.0-dev16
pennylane==0.44.0-dev42
pennylane==0.44.0-dev44
96 changes: 51 additions & 45 deletions frontend/catalyst/from_plxpr/from_plxpr.py
@@ -29,11 +29,9 @@
from pennylane.capture.expand_transforms import ExpandTransformsInterpreter
from pennylane.capture.primitives import jacobian_prim as pl_jac_prim
from pennylane.capture.primitives import transform_prim
from pennylane.transforms import cancel_inverses as pl_cancel_inverses
from pennylane.transforms import commute_controlled as pl_commute_controlled
from pennylane.transforms import decompose as pl_decompose
from pennylane.transforms import merge_amplitude_embedding as pl_merge_amplitude_embedding
from pennylane.transforms import merge_rotations as pl_merge_rotations
from pennylane.transforms import single_qubit_fusion as pl_single_qubit_fusion
from pennylane.transforms import unitary_to_rot as pl_unitary_to_rot

@@ -48,7 +46,6 @@
qdealloc_p,
quantum_kernel_p,
)
from catalyst.passes.pass_api import Pass
from catalyst.utils.patching import Patcher

from .qfunc_interpreter import PLxPRToQuantumJaxprInterpreter
@@ -286,7 +283,9 @@ def handle_qnode(
# Fallback to the legacy decomposition if the graph-based decomposition failed
if not graph_succeeded:
# Remove the decompose-lowering pass from the pipeline
self._pass_pipeline = [p for p in self._pass_pipeline if p.name != "decompose-lowering"]
self._pass_pipeline = [
p for p in self._pass_pipeline if p.pass_name != "decompose-lowering"
]
closed_jaxpr = _apply_compiler_decompose_to_plxpr(
inner_jaxpr=closed_jaxpr.jaxpr,
consts=closed_jaxpr.consts,
@@ -334,11 +333,9 @@ def calling_convention(*args):
# otherwise their value will be None. The second value indicates if the transform
# requires decomposition to be supported by Catalyst.
transforms_to_passes = {
pl_cancel_inverses: ("cancel-inverses", False),
pl_commute_controlled: (None, False),
pl_decompose: (None, False),
pl_merge_amplitude_embedding: (None, True),
pl_merge_rotations: ("merge-rotations", False),
pl_single_qubit_fusion: (None, False),
pl_unitary_to_rot: (None, False),
}
@@ -349,6 +346,47 @@ def register_transform(pl_transform, pass_name, decomposition):
transforms_to_passes[pl_transform] = (pass_name, decomposition)


def _handle_decompose_transform(self, inner_jaxpr, consts, non_const_args, tkwargs):
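"""Handle a graph-based decompose transform: record that the decompose-lowering
pass is required, prepend it to the pass pipeline, and defer the compiler-specific
decomposition to the qnode handler."""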
if not self.requires_decompose_lowering:
self.requires_decompose_lowering = True
else:
raise NotImplementedError("Multiple decomposition transforms are not yet supported.")

next_eval = copy(self)
# Update the decompose_gateset to be used by the quantum kernel primitive
# TODO: we originally wanted to treat decompose_gateset as a queue of
# gatesets to be used by the decompose-lowering pass at MLIR
# but this requires a C++ implementation of the graph-based decomposition
# which doesn't exist yet.
next_eval.decompose_tkwargs = tkwargs

# Note. We don't perform the compiler-specific decomposition here
# to be able to support multiple decomposition transforms
# and collect all the required gatesets
# as well as being able to support other transforms in between.

# The compiler specific transformation will be performed
# in the qnode handler.

# Add the decompose-lowering pass to the start of the pipeline
t = qml.transform(pass_name="decompose-lowering")
pass_container = qml.transforms.core.TransformContainer(t)
next_eval._pass_pipeline.insert(0, pass_container)

# We still need to construct and solve the graph for the current jaxpr
# and the current gateset, but we don't rewrite the jaxpr at this stage.

# gds_interpreter = DecompRuleInterpreter(*targs, **tkwargs)

# def gds_wrapper(*args):
# return gds_interpreter.eval(inner_jaxpr, consts, *args)

# final_jaxpr = jax.make_jaxpr(gds_wrapper)(*args)
# return self.eval(final_jaxpr.jaxpr, consts, *non_const_args)
return next_eval.eval(inner_jaxpr, consts, *non_const_args)


# pylint: disable=too-many-arguments
@WorkflowInterpreter.register_primitive(transform_prim)
def handle_transform(
@@ -375,45 +413,11 @@ def handle_transform(
and transform._plxpr_transform.__name__ == "decompose_plxpr_to_plxpr"
and qml.decomposition.enabled_graph()
):
# Handle the conversion from plxpr to Catalyst jaxpr for a PL transform.
if not self.requires_decompose_lowering:
self.requires_decompose_lowering = True
else:
raise NotImplementedError("Multiple decomposition transforms are not yet supported.")

next_eval = copy(self)
# Update the decompose_gateset to be used by the quantum kernel primitive
# TODO: we originally wanted to treat decompose_gateset as a queue of
# gatesets to be used by the decompose-lowering pass at MLIR
# but this requires a C++ implementation of the graph-based decomposition
# which doesn't exist yet.
next_eval.decompose_tkwargs = tkwargs

# Note. We don't perform the compiler-specific decomposition here
# to be able to support multiple decomposition transforms
# and collect all the required gatesets
# as well as being able to support other transforms in between.

# The compiler specific transformation will be performed
# in the qnode handler.

# Add the decompose-lowering pass to the start of the pipeline
next_eval._pass_pipeline.insert(0, Pass("decompose-lowering"))
return _handle_decompose_transform(self, inner_jaxpr, consts, non_const_args, tkwargs)

# We still need to construct and solve the graph based on
# the current jaxpr based on the current gateset
# but we don't rewrite the jaxpr at this stage.

# gds_interpreter = DecompRuleInterpreter(*targs, **tkwargs)

# def gds_wrapper(*args):
# return gds_interpreter.eval(inner_jaxpr, consts, *args)

# final_jaxpr = jax.make_jaxpr(gds_wrapper)(*args)
# return self.eval(final_jaxpr.jaxpr, consts, *non_const_args)
return next_eval.eval(inner_jaxpr, consts, *non_const_args)

catalyst_pass_name = transforms_to_passes.get(transform, (None,))[0]
catalyst_pass_name = transform.pass_name
if catalyst_pass_name is None:
catalyst_pass_name = transforms_to_passes.get(transform, (None,))[0]
if catalyst_pass_name is None:
# Use PL's ExpandTransformsInterpreter to expand this and any embedded
# transform according to PL rules. It works by overriding the primitive
@@ -435,7 +439,9 @@ def wrapper(*args):

# Apply the corresponding Catalyst pass counterpart
next_eval = copy(self)
next_eval._pass_pipeline.insert(0, Pass(catalyst_pass_name, *targs, **tkwargs))
t = qml.transform(pass_name=catalyst_pass_name)
bound_pass = qml.transforms.core.TransformContainer(t, args=targs, kwargs=tkwargs)
next_eval._pass_pipeline.insert(0, bound_pass)
return next_eval.eval(inner_jaxpr, consts, *non_const_args)


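To summarize the new resolution order in `handle_transform`, here is a hypothetical standalone sketch (the helper function is illustrative and not part of the diff):

```python
def resolve_catalyst_pass_name(transform, transforms_to_passes):
    # Prefer the transform's own `pass_name` property (new PennyLane API),
    # then fall back to Catalyst's legacy transform-to-pass table.
    # A result of None means the transform is expanded with PennyLane's rules instead.
    name = getattr(transform, "pass_name", None)
    if name is None:
        name = transforms_to_passes.get(transform, (None,))[0]
    return name
```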
21 changes: 18 additions & 3 deletions frontend/catalyst/jax_primitives_utils.py
@@ -318,6 +318,16 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.ctx.module_context = self.old_module_context


def _lowered_options(args, kwargs):
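"""Lower a pass's positional args and keyword options to MLIR pass-option attributes:
each positional arg becomes a boolean flag, and keyword names are converted from
snake_case to the hyphenated MLIR style."""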
lowered_options = {}
for arg in args:
lowered_options[str(arg)] = get_mlir_attribute_from_pyval(True)
for option, value in kwargs.items():
mlir_option = str(option).replace("_", "-")
lowered_options[mlir_option] = get_mlir_attribute_from_pyval(value)
return lowered_options


def transform_named_sequence_lowering(jax_ctx: mlir.LoweringRuleContext, pipeline):
"""Generate a transform module embedded in the current module and schedule
the transformations in pipeline"""
@@ -364,11 +374,16 @@ def transform_named_sequence_lowering(jax_ctx: mlir.LoweringRuleContext, pipelin
with ir.InsertionPoint(bb_named_sequence):
target = bb_named_sequence.arguments[0]
for _pass in pipeline:
options = _pass.get_options()
if isinstance(_pass, qml.transforms.core.TransformContainer):
options = _lowered_options(_pass.args, _pass.kwargs)
name = _pass.pass_name
else:
options = _pass.get_options()
name = _pass.name
apply_registered_pass_op = ApplyRegisteredPassOp(
result=transform_mod_type,
target=target,
pass_name=_pass.name,
pass_name=name,
options=options,
dynamic_options={},
)
Expand All @@ -380,7 +395,7 @@ def transform_named_sequence_lowering(jax_ctx: mlir.LoweringRuleContext, pipelin
is_xdsl_pass,
)

if is_xdsl_pass(_pass.name):
if is_xdsl_pass(name):
uses_xdsl_passes = True
apply_registered_pass_op.operation.attributes["catalyst.xdsl_pass"] = (
ir.UnitAttr.get()
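As a rough illustration of the option translation performed by `_lowered_options`, here is a plain-Python sketch with the MLIR attribute construction left out (so it is not the real lowering):

```python
def lowered_options_sketch(args, kwargs):
    # Positional transform args are lowered as boolean flags.
    options = {str(arg): True for arg in args}
    # Keyword option names are hyphenated to match MLIR pass-option style.
    options.update({str(key).replace("_", "-"): value for key, value in kwargs.items()})
    return options

# Mirrors the expectation in the new pytest below:
# lowered_options_sketch((), {"my_option": "my_option_value", "my_other_option": False})
# returns {"my-option": "my_option_value", "my-other-option": False}
```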
36 changes: 30 additions & 6 deletions frontend/catalyst/qfunc.py
@@ -285,18 +285,23 @@ def __call__(self, *args, **kwargs):

assert isinstance(self, qml.QNode)

new_transform_program, new_pipeline = _extract_passes(self.transform_program)
# Update the qnode with peephole pipeline
pass_pipeline = kwargs.pop("pass_pipeline", [])
pass_pipeline = dictionary_to_list_of_passes(pass_pipeline)
old_pipeline = kwargs.pop("pass_pipeline", None)
processed_old_pipeline = tuple(dictionary_to_list_of_passes(old_pipeline))
pass_pipeline = processed_old_pipeline + new_pipeline
new_qnode = copy(self)
# pylint: disable=attribute-defined-outside-init, protected-access
new_qnode._transform_program = new_transform_program

# Mid-circuit measurement configuration/execution
fn_result = configure_mcm_and_try_one_shot(self, args, kwargs, pass_pipeline)
fn_result = configure_mcm_and_try_one_shot(new_qnode, args, kwargs, pass_pipeline)

# If the qnode failed to execute as one-shot, fn_result will be None
if fn_result is not None:
return fn_result

new_device = copy(self.device)
new_device = copy(new_qnode.device)
qjit_device = QJITDevice(new_device)

static_argnums = kwargs.pop("static_argnums", ())
@@ -307,11 +312,11 @@

def _eval_quantum(*args, **kwargs):
trace_result = trace_quantum_function(
self.func,
new_qnode.func,
qjit_device,
args,
kwargs,
self,
new_qnode,
static_argnums,
debug_info,
)
@@ -655,3 +660,22 @@ def wrap_single_shot_qnode(*_):
return _finalize_output(out, ctx)

return one_shot_wrapper


def _extract_passes(transform_program):
"""Extract transforms with pass names from the end of the TransformProgram."""
tape_transforms = []
pass_pipeline = []
i = len(transform_program)
for t in reversed(transform_program):
if t.pass_name is None:
break
i -= 1
pass_pipeline = transform_program[i:]
tape_transforms = transform_program[:i]
for t in tape_transforms:
if t.transform is None:
raise ValueError(
f"{t} without a tape definition occurs before tape transform {tape_transforms[-1]}."
)
return qml.transforms.core.TransformProgram(tape_transforms), tuple(pass_pipeline)
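A hypothetical usage sketch of `_extract_passes` (the transform names are examples, and it assumes a `pass_name`-only transform can be added to a `TransformProgram` like any other dispatcher):

```python
import pennylane as qml

program = qml.transforms.core.TransformProgram()
program.add_transform(qml.transforms.undo_swaps)                    # tape transform: stays in the program
program.add_transform(qml.transform(pass_name="cancel-inverses"))   # trailing named pass
program.add_transform(qml.transform(pass_name="merge-rotations"))   # trailing named pass

tape_program, pass_pipeline = _extract_passes(program)
# tape_program keeps only undo_swaps; pass_pipeline holds the two named passes in
# order, ready to be lowered to transform.apply_registered_pass operations.
```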
1 change: 1 addition & 0 deletions frontend/test/lit/test_decomposition.py
@@ -1,4 +1,4 @@
# Copyright 2022-2025 Xanadu Quantum Technologies Inc.

Check notice on line 1 (codefactor.io / CodeFactor): Missing module docstring (missing-module-docstring)
import os
import pathlib
import platform
@@ -46,6 +46,7 @@
error_msg = str(e)
if (
"Unsupported type annotation None for parameter pauli_word" in error_msg
or "Unsupported type annotation <class 'str'> for parameter pauli_word" in error_msg
or "index is out of bounds for axis" in error_msg
):
print(f"# SKIPPED {test_func.__name__}: PauliRot type annotation issue")
36 changes: 36 additions & 0 deletions frontend/test/lit/test_peephole_optimizations.py
@@ -87,6 +87,42 @@ def test_pipeline_lowering_workflow(x):
test_pipeline_lowering()


def test_transform_lowering():
"""
Basic lowering of PennyLane transforms to Catalyst passes on one qnode.
"""

@qjit(keep_intermediate=True)
@qml.transforms.merge_rotations
@qml.transforms.cancel_inverses
@qml.qnode(qml.device("lightning.qubit", wires=2))
def test_pipeline_lowering_workflow(x):
qml.RX(x, wires=[0])
qml.Hadamard(wires=[1])
qml.Hadamard(wires=[1])
return qml.expval(qml.PauliY(wires=0))

# CHECK: pipeline=(<cancel_inverses((), {})>, <merge_rotations((), {})>)
print_jaxpr(test_pipeline_lowering_workflow, 1.2)

# CHECK: transform.named_sequence @__transform_main
# CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "cancel-inverses" to {{%.+}}
# CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "merge-rotations" to {{%.+}}
# CHECK-NEXT: transform.yield
print_mlir(test_pipeline_lowering_workflow, 1.2)

# CHECK: {{%.+}} = call @test_pipeline_lowering_workflow_0(
# CHECK: func.func public @test_pipeline_lowering_workflow_0(
# CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit
# CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit
# CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit
test_pipeline_lowering_workflow(42.42)
flush_peephole_opted_mlir_to_iostream(test_pipeline_lowering_workflow)


test_transform_lowering()


def test_pipeline_lowering_keep_original():
"""
Test when the pipelined qnode and the original qnode are both used,
21 changes: 20 additions & 1 deletion frontend/test/pytest/from_plxpr/test_capture_integration.py
@@ -1048,7 +1048,7 @@ def circuit(x: float):
assert jnp.allclose(circuit(0.1), capture_result)

@pytest.mark.usefixtures("use_capture")
def test_pass_with_options(self, backend):
def test_pass_with_options_patch(self, backend):
"""Test the integration for a circuit with a pass that takes in options."""

@qml.transform
@@ -1071,6 +1071,25 @@ def captured_circuit():
in capture_mlir
)

@pytest.mark.usefixtures("use_capture")
def test_pass_with_options(self, backend):
"""Test the integration for a circuit with a pass that takes in options."""

my_pass = qml.transform(pass_name="my-pass")

@qjit(target="mlir")
@partial(my_pass, my_option="my_option_value", my_other_option=False)
@qml.qnode(qml.device(backend, wires=1))
def captured_circuit():
return qml.expval(qml.PauliZ(0))

capture_mlir = captured_circuit.mlir
assert 'transform.apply_registered_pass "my-pass"' in capture_mlir
assert (
'with options = {"my-option" = "my_option_value", "my-other-option" = false}'
in capture_mlir
)

def test_transform_cancel_inverses_workflow(self, backend):
"""Test the integration for a circuit with a 'cancel_inverses' transform."""
