Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
7d393f7
rs-decomp
josephleekl Oct 21, 2025
6189426
update
josephleekl Oct 23, 2025
b3be643
update
josephleekl Oct 24, 2025
ba25e5a
update
josephleekl Oct 24, 2025
2ab894f
update
josephleekl Oct 24, 2025
1947de4
update
josephleekl Oct 27, 2025
45a64b7
update
josephleekl Oct 27, 2025
5402851
remove unnecessary include
josephleekl Oct 28, 2025
8748550
update
josephleekl Oct 29, 2025
e49db23
Merge branch 'main' into rs-decomp
josephleekl Oct 29, 2025
79d6712
dummy impl
josephleekl Oct 29, 2025
b10d42c
remove debug
josephleekl Oct 29, 2025
328a0be
clean
josephleekl Oct 29, 2025
4648c0e
decomp takes theta and epsilon
josephleekl Nov 5, 2025
d3635d1
Merge branch 'main' into rs-decomp
josephleekl Nov 5, 2025
b7e4f75
Merge branch 'main' into rs-decomp
josephleekl Nov 7, 2025
fbde78a
support RZ and updated frontend
josephleekl Nov 7, 2025
9114b32
format
josephleekl Nov 7, 2025
04fb5af
Merge branch 'main' into rs-decomp
josephleekl Nov 7, 2025
e5f2bf4
fix sourceop has more than one input/output
josephleekl Nov 10, 2025
c3cd186
allow direct lowering to ppr
josephleekl Nov 10, 2025
b929fe5
Merge branch 'main' into rs-decomp
josephleekl Nov 10, 2025
828c179
format
josephleekl Nov 10, 2025
5caaa53
clean up
josephleekl Nov 10, 2025
8fb7605
fix mac segfault
josephleekl Nov 11, 2025
f29c9ba
extractop allow dynamic/static indices
josephleekl Nov 11, 2025
6537d4f
Merge branch 'main' into rs-decomp
josephleekl Nov 11, 2025
2afb190
change rsdecomposition to gridsynth
josephleekl Nov 11, 2025
93c04c4
fix missing dialect prefix in xdsl detection
mehrdad2m Nov 11, 2025
d8ba2c0
Merge branch 'main' of https://github.com/PennyLaneAI/catalyst into r…
lazypanda10117 Nov 14, 2025
e39bb99
Merge branch 'main' of https://github.com/PennyLaneAI/catalyst into r…
lazypanda10117 Nov 14, 2025
acf887a
Merge branch 'main' into rs-decomp
josephleekl Nov 19, 2025
183e262
Merge branch 'main' into rs-decomp
josephleekl Nov 20, 2025
8126f9b
update mlir pass
josephleekl Nov 20, 2025
f42e091
Merge branch 'main' into rs-decomp
josephleekl Nov 25, 2025
41b95e8
Update mlir/include/Quantum/Transforms/Passes.td
josephleekl Nov 25, 2025
c4e8a64
update RSDecomp file name
josephleekl Nov 25, 2025
b62ab74
use pennylane branch for doc
josephleekl Nov 25, 2025
a148efe
update dep version for docs
josephleekl Nov 25, 2025
f59612b
add test
josephleekl Nov 25, 2025
5cc42fb
add docs
josephleekl Nov 25, 2025
1ae3c54
Update frontend/catalyst/compiler.py
josephleekl Nov 25, 2025
07153e2
docs
josephleekl Nov 25, 2025
94e550f
small cleanup
josephleekl Nov 25, 2025
d8be26a
fix phaseshift phase
josephleekl Nov 25, 2025
17a0778
review comments
josephleekl Nov 28, 2025
80cabdf
isExternal -> getCallableRegion
paul0403 Dec 1, 2025
18dd269
Apply suggestions from code review
josephleekl Dec 2, 2025
fb2c012
Merge branch 'main' into rs-decomp
josephleekl Dec 2, 2025
ea84709
code review
josephleekl Dec 2, 2025
2977e98
Add `estimated_iterations` attribute to forOp in GridsynthPatterns
sengthai Dec 2, 2025
067a398
use heap alloc and add doc
josephleekl Dec 2, 2025
736b357
review comments - template declarefunc, alloc
josephleekl Dec 2, 2025
b454fbe
break up big getOrCreateDecompositionFunc function
josephleekl Dec 2, 2025
9ef0342
remove unnecessary using namespace
josephleekl Dec 2, 2025
e8cc971
Merge branch 'main' into rs-decomp
josephleekl Dec 2, 2025
22acc0d
add frontend lit test
josephleekl Dec 2, 2025
f4aeb50
codefactor
josephleekl Dec 2, 2025
a28ae7f
use size_t instead of int64_t for runtime func
josephleekl Dec 2, 2025
b78d2b7
update test
josephleekl Dec 2, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions doc/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ matplotlib==3.10.0
lxml_html_clean

# Pre-install PL development wheels
--extra-index-url https://test.pypi.org/simple/
pennylane-lightning-kokkos==0.44.0-dev16
pennylane-lightning==0.44.0-dev16
pennylane==0.44.0-dev42
git+https://github.com/PennyLaneAI/pennylane.git@rs-decomp
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here

1 change: 1 addition & 0 deletions frontend/catalyst/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@
f"-l{lapack_lib_name}", # required for custom_calls lib
"-lcustom_calls",
"-lmlir_async_runtime",
"-lrt_rsdecomp",
]

# If OQD runtime capi is built, link to it as well
Expand Down Expand Up @@ -594,7 +595,7 @@
workspace: The workspace directory path

Returns:
Callable or None: The callback function if intermediate saving is enabled, None otherwise

Check notice on line 598 in frontend/catalyst/compiler.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/compiler.py#L598

Line too long (101/100) (line-too-long)
"""
if not (workspace and self.options.keep_intermediate >= KeepIntermediateLevel.CHANGED):
return None
Expand Down
2 changes: 2 additions & 0 deletions frontend/catalyst/from_plxpr/from_plxpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from pennylane.transforms import cancel_inverses as pl_cancel_inverses
from pennylane.transforms import commute_controlled as pl_commute_controlled
from pennylane.transforms import decompose as pl_decompose
from pennylane.transforms import gridsynth as pl_gridsynth
from pennylane.transforms import merge_amplitude_embedding as pl_merge_amplitude_embedding
from pennylane.transforms import merge_rotations as pl_merge_rotations
from pennylane.transforms import single_qubit_fusion as pl_single_qubit_fusion
Expand Down Expand Up @@ -341,6 +342,7 @@ def calling_convention(*args):
pl_merge_rotations: ("merge-rotations", False),
pl_single_qubit_fusion: (None, False),
pl_unitary_to_rot: (None, False),
pl_gridsynth: ("gridsynth", False),
}


Expand Down
2 changes: 2 additions & 0 deletions frontend/catalyst/passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
commute_ppr,
disentangle_cnot,
disentangle_swap,
gridsynth,
ions_decomposition,
merge_ppr_ppm,
merge_rotations,
Expand All @@ -51,6 +52,7 @@
from catalyst.passes.pass_api import Pass, PassPlugin, apply_pass, apply_pass_plugin

__all__ = (
"gridsynth",
"to_ppr",
"commute_ppr",
"ppr_to_ppm",
Expand Down
67 changes: 67 additions & 0 deletions frontend/catalyst/passes/builtin_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,73 @@ def circuit():
return PassPipelineWrapper(qnode, "ions-decomposition")


def gridsynth(qnode=None, *, epsilon=1e-4, ppr_basis=False):
"""
Specify that the ``--gridsynth`` MLIR compiler pass to discretize
single-qubit RZ and PhaseShift gates into sequences of
Clifford+T gates using the Ross-Selinger Gridsynth algorithm.
Reference: https://arxiv.org/abs/1403.2975


.. note::

The actual discretization is only performed during execution time.

Args:
fn (QNode): the QNode to apply the gridsynth compiler pass to
epsilon (float): the maximum error tolerance for the per-gate discretization
ppr_basis (bool): whether to decompose directly to Pauli Product Rotations (PPRs) in QEC dialect

Returns:
:class:`QNode <pennylane.QNode>`

**Example**

In this example the RZ gate will be converted into a new function, which
calls the discretization at execution time.

.. code-block:: python

from catalyst.passes import gridsynth

dev = qml.device("lightning.qubit", wires=1)

@qjit(keep_intermediate=True)
@gridsynth
@qml.qnode(dev)
def circuit(x: float):
qml.RZ(x, wires=0)
return qml.expval(qml.PauliZ(0))

Example MLIR Representation:

.. code-block:: mlir

module @circuit {
. . .
func.func private @rs_decomposition_get_phase_0(f64, f64, i1) -> f64
func.func private @rs_decomposition_get_gates_0(memref<?xindex>, f64, f64, i1)
func.func private @rs_decomposition_get_size_0(f64, f64, i1) -> index
func.func private @__catalyst_decompose_RZ_0(%arg0: !quantum.reg, %arg1: i64, %arg2: f64) -> (!quantum.reg, f64) {
. . .
}

func.func public @circuit_0(%arg0: tensor<f64>) -> tensor<f64> attributes {diff_method = "adjoint", llvm.linkage = #llvm.linkage<internal>, qnode} {
. . .
%3:2 = call @__catalyst_decompose_RZ_0(%2, %c0_i64, %extracted) : (!quantum.reg, i64, f64) -> (!quantum.reg, f64)
. . .
}
}


"""
if qnode is None:
return functools.partial(gridsynth, epsilon=epsilon, ppr_basis=ppr_basis)

gridsynth_pass = {"gridsynth": {"epsilon": epsilon, "ppr_basis": ppr_basis}}
return PassPipelineWrapper(qnode, gridsynth_pass)


def to_ppr(qnode):
R"""
A quantum compilation pass that converts Clifford+T gates into Pauli Product Rotation (PPR)
Expand Down
1 change: 1 addition & 0 deletions frontend/catalyst/passes/pass_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ def dictionary_to_list_of_passes(pass_pipeline: PipelineDict | str, *flags, **va

def _API_name_to_pass_name():
return {
"gridsynth": "gridsynth",
"cancel_inverses": "cancel-inverses",
"decompose_lowering": "decompose-lowering",
"disentangle_cnot": "disentangle-CNOT",
Expand Down
266 changes: 266 additions & 0 deletions frontend/test/lit/test_gridsynth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
# Copyright 2025 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Unit tests for the gridsynth decomposition pass.
"""

# RUN: %PYTHON %s | FileCheck %s
# pylint: disable=line-too-long

from functools import partial

import pennylane as qml

from catalyst import qjit
from catalyst.passes import gridsynth

# Pipeline to stop after quantum compilation (where gridsynth runs)
# This prevents lowerings that might fail for qec.ppr.
pipe = [("pipe", ["quantum-compilation-stage"])]

# ==============================================================================
# Test 1: RZ Registration (Clifford+T basis)
# ==============================================================================


def test_rz_registration():
"""Test that the gridsynth pass is correctly registered for RZ."""

@qjit(target="mlir")
@gridsynth(epsilon=0.01)
@qml.qnode(qml.device("lightning.qubit", wires=1))
def circuit(x: float):
qml.RZ(x, wires=0)
return qml.probs()

# CHECK-LABEL: test_rz_registration
print("test_rz_registration")
# CHECK: transform.named_sequence @__transform_main
# CHECK: transform.apply_registered_pass "gridsynth" with options = {{[{]}}"epsilon" = 1.000000e-02 : f64, "ppr-basis" = false{{[}]}}
# CHECK-LABEL: func.func public @circuit
# CHECK: quantum.custom "RZ"
print(circuit.mlir)


test_rz_registration()

# ==============================================================================
# Test 2: RZ Lowering (Clifford+T basis)
# ==============================================================================


def test_rz_lowering():
"""Test that RZ is correctly lowered to the decomposition function."""

@qjit(target="mlir", pipelines=pipe)
@gridsynth(epsilon=0.01)
@qml.qnode(qml.device("lightning.qubit", wires=1))
def circuit(x: float):
qml.RZ(x, wires=0)
return qml.probs()

# CHECK-LABEL: test_rz_lowering
print("test_rz_lowering")

# CHECK-LABEL: func.func private @__catalyst_decompose_RZ{{.*}}
# CHECK: scf.index_switch
# CHECK: case 0 {
# CHECK: quantum.custom "T"
# CHECK: }

# CHECK-LABEL: func.func public @circuit{{.*}}
# CHECK-NOT: quantum.custom "RZ"
# CHECK: call @__catalyst_decompose_RZ{{.*}}
print(circuit.mlir_opt)


test_rz_lowering()

# ==============================================================================
# Test 3: PhaseShift Registration
# ==============================================================================


def test_phaseshift_registration():
"""Test that the gridsynth pass is correctly registered for PhaseShift."""

@qjit(target="mlir")
@gridsynth(epsilon=0.01)
@qml.qnode(qml.device("lightning.qubit", wires=1))
def circuit(x: float):
qml.PhaseShift(x, wires=0)
return qml.probs()

# CHECK-LABEL: test_phaseshift_registration
print("test_phaseshift_registration")
# CHECK: transform.apply_registered_pass "gridsynth"
# CHECK-LABEL: func.func public @circuit
# CHECK: quantum.custom "PhaseShift"
print(circuit.mlir)


test_phaseshift_registration()

# ==============================================================================
# Test 4: PhaseShift Lowering
# ==============================================================================


def test_phaseshift_lowering():
"""Test that PhaseShift is decomposed into RZ + GlobalPhase."""

@qjit(target="mlir", pipelines=pipe)
@gridsynth(epsilon=0.01)
@qml.qnode(qml.device("lightning.qubit", wires=1))
def circuit(x: float):
qml.PhaseShift(x, wires=0)
return qml.probs()

# CHECK-LABEL: test_phaseshift_lowering
print("test_phaseshift_lowering")

# CHECK-LABEL: func.func private @__catalyst_decompose_RZ{{.*}}

# CHECK-LABEL: func.func public @circuit{{.*}}
# CHECK: call @__catalyst_decompose_RZ{{.*}}
# CHECK: quantum.gphase
print(circuit.mlir_opt)


test_phaseshift_lowering()

# ==============================================================================
# Test 5: PPR Registration
# ==============================================================================


def test_ppr_registration():
"""Test that ppr_basis=True is passed to the transform."""

@qjit(target="mlir")
@gridsynth(epsilon=0.01, ppr_basis=True)
@qml.qnode(qml.device("lightning.qubit", wires=1))
def circuit(x: float):
qml.RZ(x, wires=0)
return qml.probs()

# CHECK-LABEL: test_ppr_registration
print("test_ppr_registration")
# CHECK: transform.apply_registered_pass "gridsynth" with options = {{[{]}}"epsilon" = 1.000000e-02 : f64, "ppr-basis" = true{{[}]}}
print(circuit.mlir)


test_ppr_registration()

# ==============================================================================
# Test 6: PPR Lowering
# ==============================================================================


def test_ppr_lowering():
"""Test that PPR basis generates qec.ppr operations."""

@qjit(target="mlir", pipelines=pipe)
@gridsynth(epsilon=0.01, ppr_basis=True)
@qml.qnode(qml.device("lightning.qubit", wires=1))
def circuit(x: float):
qml.RZ(x, wires=0)
return qml.probs()

# CHECK-LABEL: test_ppr_lowering
print("test_ppr_lowering")

# CHECK-LABEL: func.func private @__catalyst_decompose_RZ_ppr_basis{{.*}}
# CHECK: scf.index_switch
# CHECK: case 1 {
# CHECK: qec.ppr ["X"](2)
# CHECK: }

# CHECK-LABEL: func.func public @circuit{{.*}}
# CHECK: call @__catalyst_decompose_RZ_ppr_basis{{.*}}
print(circuit.mlir_opt)


test_ppr_lowering()


# ==============================================================================
# Test 7: Capture Workflow Lowering (Clifford+T)
# ==============================================================================


def test_capture_workflow_clifford():
"""Test the capture workflow with qml.transforms.gridsynth (Clifford+T)."""
qml.capture.enable()

@qjit(target="mlir", pipelines=pipe)
@partial(qml.transforms.gridsynth, epsilon=0.01, ppr_basis=False)
@qml.qnode(qml.device("lightning.qubit", wires=1))
def circuit(x: float):
qml.RZ(x, wires=0)
return qml.probs()

# CHECK-LABEL: test_capture_workflow_clifford
print("test_capture_workflow_clifford")

# CHECK-LABEL: func.func private @__catalyst_decompose_RZ{{.*}}
# CHECK: scf.index_switch
# CHECK: case 0 {
# CHECK: quantum.custom "T"
# CHECK: }

# CHECK-LABEL: func.func public @circuit{{.*}}
# CHECK-NOT: quantum.custom "RZ"
# CHECK: call @__catalyst_decompose_RZ{{.*}}
print(circuit.mlir_opt)

qml.capture.disable()


test_capture_workflow_clifford()

# ==============================================================================
# Test 8: Capture Workflow Lowering (PPR)
# ==============================================================================


def test_capture_workflow_ppr():
"""Test the capture workflow with qml.transforms.gridsynth (PPR)."""
qml.capture.enable()

@qjit(target="mlir", pipelines=pipe)
@partial(qml.transforms.gridsynth, epsilon=0.01, ppr_basis=True)
@qml.qnode(qml.device("lightning.qubit", wires=1))
def circuit(x: float):
qml.RZ(x, wires=0)
return qml.probs()

# CHECK-LABEL: test_capture_workflow_ppr
print("test_capture_workflow_ppr")

# CHECK-LABEL: func.func private @__catalyst_decompose_RZ_ppr_basis{{.*}}
# CHECK: scf.index_switch
# CHECK: case 1 {
# CHECK: qec.ppr ["X"](2)
# CHECK: }

# CHECK-LABEL: func.func public @circuit{{.*}}
# CHECK: call @__catalyst_decompose_RZ_ppr_basis{{.*}}
print(circuit.mlir_opt)

qml.capture.disable()


test_capture_workflow_ppr()
Loading
Loading