Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create a frontend UI for users to specify quantum compilation pipelines #1131

Merged
merged 45 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
69c7933
pulling ZNE mitigation (`-lower-mitigation`) into the transform_named…
paul0403 Sep 11, 2024
6974981
Merge remote-tracking branch 'origin/main' into pass_pipeline_UI
paul0403 Sep 11, 2024
4761364
update passes.mitigate_with_zne test to include new `scale_factors`
paul0403 Sep 11, 2024
b5ec1b5
Merge remote-tracking branch 'origin/main' into pass_pipeline_UI
paul0403 Sep 11, 2024
13899c1
Merge remote-tracking branch 'origin/main' into pass_pipeline_UI
paul0403 Sep 12, 2024
2a38c02
init the local pipeline decorator
paul0403 Sep 12, 2024
986cc94
lit tests
paul0403 Sep 13, 2024
d7b98ae
add global peephole pipeline option in qjit
paul0403 Sep 13, 2024
00b28d7
format
paul0403 Sep 13, 2024
4efb372
Merge remote-tracking branch 'origin/main' into pass_pipeline_UI
paul0403 Sep 16, 2024
300c61c
reverting zne changes; this PR will leave zne untouched
paul0403 Sep 16, 2024
30b9f41
create merge rotation pass boilerplate
paul0403 Sep 16, 2024
99a3f13
reverting zne to main
paul0403 Sep 16, 2024
c23f52b
rewriting tests to exclude zne
paul0403 Sep 16, 2024
6f5dd59
put back lower-mitigation in default pipeline
paul0403 Sep 16, 2024
7dc40cd
codefactor
paul0403 Sep 16, 2024
1be95b5
local pipelines override global pipelines
paul0403 Sep 16, 2024
50f2585
make sure cudajit (which will never have the pass_pipeline) does not …
paul0403 Sep 16, 2024
d0d29c2
format
paul0403 Sep 16, 2024
60769e3
codefactor
paul0403 Sep 17, 2024
7648f88
Merge remote-tracking branch 'origin/main' into pass_pipeline_UI
paul0403 Sep 17, 2024
bc521c9
format
paul0403 Sep 17, 2024
8ed20c5
add support for pass options
paul0403 Sep 17, 2024
652d96a
format
paul0403 Sep 17, 2024
46f640f
Merge remote-tracking branch 'origin/main' into pass_pipeline_UI
paul0403 Sep 17, 2024
8d24767
type hint pipeline
paul0403 Sep 17, 2024
d410afb
no documentation for helpers in passes.py
paul0403 Sep 17, 2024
db3bd48
format
paul0403 Sep 17, 2024
551545d
add pytest for pass option effect
paul0403 Sep 17, 2024
6b20fc5
documentation
paul0403 Sep 17, 2024
6618e97
changelog
paul0403 Sep 17, 2024
72e6fc0
codefactor line too long in documentation
paul0403 Sep 17, 2024
a3b5425
Merge remote-tracking branch 'origin/main' into pass_pipeline_UI
paul0403 Sep 18, 2024
be02cad
typo
paul0403 Sep 18, 2024
9195851
add quantum scope TODO
paul0403 Sep 18, 2024
0850516
double ticks in documentation instead of single tick for code words
paul0403 Sep 18, 2024
2eb6f2f
Merge remote-tracking branch 'origin/main' into pass_pipeline_UI
paul0403 Sep 18, 2024
d2e5abb
Merge remote-tracking branch 'origin/main' into pass_pipeline_UI
paul0403 Sep 18, 2024
2ddf0af
rename variables in changelog
paul0403 Sep 18, 2024
8ae031f
change name uniquer to an import
paul0403 Sep 18, 2024
536d651
Merge remote-tracking branch 'origin/main' into pass_pipeline_UI
paul0403 Sep 19, 2024
f0a5bb6
Merge remote-tracking branch 'origin/main' into pass_pipeline_UI
paul0403 Sep 19, 2024
1511e17
remove "merge_rotations" from public documentation
paul0403 Sep 19, 2024
7c2142e
typo
paul0403 Sep 19, 2024
81b5006
add web link to `catalyst.passes` in documentation
paul0403 Sep 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion frontend/catalyst/api_extensions/error_mitigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
import pennylane as qml
from jax._src.tree_util import tree_flatten

from catalyst.jax_primitives import Folding, zne_p
from catalyst.jax_primitives import Folding, apply_registered_pass_p, zne_p
from catalyst.tracing.contexts import EvaluationContext


def _is_odd_positive(numbers_list):
Expand Down Expand Up @@ -164,6 +165,15 @@ def __init__(
):
if not isinstance(fn, qml.QNode):
raise TypeError(f"A QNode is expected, got the classical function {fn}")

wrapped_qnode_function = fn.func

def wrapper(*args, **kwrags):
if EvaluationContext.is_tracing():
apply_registered_pass_p.bind(pass_name="lower-mitigation")
return wrapped_qnode_function(*args, **kwrags)

fn.func = wrapper
self.fn = fn
self.__name__ = f"zne.{getattr(fn, '__name__', 'unknown')}"
self.num_folds = num_folds
Expand Down
1 change: 0 additions & 1 deletion frontend/catalyst/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,6 @@ def run_writing_command(command: List[str], compile_options: Optional[CompileOpt
[
"apply-transform-sequence", # Run the transform sequence defined in the MLIR module
"annotate-function",
"lower-mitigation",
"lower-gradients",
"adjoint-lowering",
"disable-assertion",
Expand Down
15 changes: 14 additions & 1 deletion frontend/catalyst/jax_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,8 @@
"""

# If there already is a apply_registered_pass,
# insert after the last pass in the existing pass sequence.
# insert before the first pass in the existing pass sequence.
# See comment in frontend/catalyst/passes.py/pipeline()
# Note that ir.InsertionPoint(op) sets the insertion point to immediately BEFORE the op
named_sequence_op_block = named_sequence_op.regions[0].blocks[0]
first_op_in_block = named_sequence_op_block.operations[0].operation
Expand All @@ -528,16 +529,28 @@
"""

if first_op_in_block.name == "transform.apply_registered_pass":
"""

Check notice on line 532 in frontend/catalyst/jax_primitives.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_primitives.py#L532

String statement has no effect (pointless-string-statement)
paul0403 marked this conversation as resolved.
Show resolved Hide resolved
_ = len(named_sequence_op_block.operations)
yield_op = named_sequence_op_block.operations[_ - 1].operation
current_last_pass = named_sequence_op_block.operations[_ - 2].operation

with ir.InsertionPoint(yield_op):
apply_registered_pass_op = ApplyRegisteredPassOp(
result=transform_mod_type,
target=current_last_pass.result,
pass_name=pass_name,
options=options,
)
"""
current_first_pass = named_sequence_op_block.operations[0].operation
with ir.InsertionPoint(first_op_in_block):
apply_registered_pass_op = ApplyRegisteredPassOp(
result=transform_mod_type,
target=named_sequence_op.regions[0].blocks[0].arguments[0],
pass_name=pass_name,
options=options,
)
current_first_pass.operands[0] = apply_registered_pass_op.result

# otherwise it's the first pass, i.e. only a yield op is in the block
# so insert right before the yield op
Expand Down
120 changes: 99 additions & 21 deletions frontend/catalyst/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,65 @@
"""

import copy
import functools

import pennylane as qml

from catalyst.api_extensions import (
mitigate_with_zne as _mitigate_with_zne_api_extensions,
)
from catalyst.jax_primitives import apply_registered_pass_p, transform_named_sequence_p
from catalyst.tracing.contexts import EvaluationContext


## API ##
# pylint: disable=line-too-long
def cancel_inverses(fn=None):
def pipeline(fn=None, *, pass_pipeline=None):
paul0403 marked this conversation as resolved.
Show resolved Hide resolved
"""
Here are documentation words
"""
paul0403 marked this conversation as resolved.
Show resolved Hide resolved

"""

Check notice on line 54 in frontend/catalyst/passes.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/passes.py#L54

String statement has no effect (pointless-string-statement)
Implementation design: it just stacks the decorators for the user automatically.
e.g.pass_pipeline = {
"mitigate_with_zne": {"scale_factors": [1, 2, 3]},
"cancel_inverses": {}
}
will just do
fn = mitigate_with_zne(scale_factors=[1, 2, 3])(fn)
fn = cancel_inverses(fn)

Note that since python 3.7, dictionary order is guaranteed to be in insertion order
https://docs.python.org/3/library/stdtypes.html#dict.values
"""
kwargs = copy.copy(locals())
kwargs.pop("fn")

if fn is None:
return functools.partial(pipeline, **kwargs)

if not isinstance(fn, qml.QNode):
raise TypeError(f"A QNode is expected, got the classical function {fn}")

fn_original_name = fn.__name__

API_calls = API_name_to_API_calls()
fn_clone = copy.copy(fn)
fn_clone.__name__ = fn_original_name + "_transformed"
# Note: we create wrappers to inject the apply_registered_pass primitive
# In other words, the last API call, aka the last wrapper, will be the outermost primitive
# Therefore when lowering these primitives to mlir (in jax_primitives.py), lower them in reverse order!
for decorator, decorator_args in pass_pipeline.items():
print("seeing pass: ", decorator)
if decorator not in API_calls.keys():
raise RuntimeError(f"{decorator} is not a valid quantum transformation pass")

fn_clone = API_calls[decorator](fn_clone, keep_original=False, **decorator_args)

return fn_clone


def cancel_inverses(fn=None, keep_original=True):
"""
Specify that the ``-removed-chained-self-inverse`` MLIR compiler pass
for cancelling two neighbouring self-inverse
Expand Down Expand Up @@ -137,33 +187,61 @@
if not isinstance(fn, qml.QNode):
raise TypeError(f"A QNode is expected, got the classical function {fn}")

wrapped_qnode_function = fn.func
funcname = fn.__name__
wrapped_qnode_function = fn.func

def wrapper(*args, **kwrags):
# TODO: hint the compiler which qnodes to run the pass on via an func attribute,
# instead of the qnode name. That way the clone can have this attribute and
# the original can just not have it.
# We are not doing this right now and passing by name because this would
# be a discardable attribute (i.e. a user/developer wouldn't know that this
# attribute exists just by looking at qnode's documentation)
# But when we add the full peephole pipeline in the future, the attribute
# could get properly documented.

apply_registered_pass_p.bind(
pass_name="remove-chained-self-inverse",
options=f"func-name={funcname}" + "_cancel_inverses",
)
return wrapped_qnode_function(*args, **kwrags)
if keep_original:

fn_clone = copy.copy(fn)
fn_clone.func = wrapper
fn_clone.__name__ = funcname + "_cancel_inverses"
def wrapper(*args, **kwrags):
# TODO: hint the compiler which qnodes to run the pass on via an func attribute,
# instead of the qnode name. That way the clone can have this attribute and
# the original can just not have it.
# We are not doing this right now and passing by name because this would
# be a discardable attribute (i.e. a user/developer wouldn't know that this
# attribute exists just by looking at qnode's documentation)
# But when we add the full peephole pipeline in the future, the attribute
# could get properly documented.

return fn_clone
if EvaluationContext.is_tracing():
apply_registered_pass_p.bind(
pass_name="remove-chained-self-inverse",
options=f"func-name={funcname}" + "_cancel_inverses",
)
return wrapped_qnode_function(*args, **kwrags)

fn_clone = copy.copy(fn)
fn_clone.func = wrapper
fn_clone.__name__ = funcname + "_cancel_inverses"

return fn_clone

else:

def wrapper(*args, **kwrags):
if EvaluationContext.is_tracing():
apply_registered_pass_p.bind(
pass_name="remove-chained-self-inverse",
options=f"func-name={funcname}",
)
return wrapped_qnode_function(*args, **kwrags)

Check notice on line 226 in frontend/catalyst/passes.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/passes.py#L226

Missing function or method docstring (missing-function-docstring)

fn.func = wrapper
return fn


def mitigate_with_zne(*args, keep_original=False, **kwrags):

Check notice on line 232 in frontend/catalyst/passes.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/passes.py#L232

Unused argument 'keep_original' (unused-argument)
"""
An alias of catalyst.mitigate_with_zne.
See https://docs.pennylane.ai/projects/catalyst/en/stable/code/api/catalyst.mitigate_with_zne.html
"""
return _mitigate_with_zne_api_extensions(*args, **kwrags)


## IMPL and helpers ##
def API_name_to_API_calls():
return {"cancel_inverses": cancel_inverses, "mitigate_with_zne": mitigate_with_zne}


def _inject_transform_named_sequence():
"""
Inject a transform_named_sequence jax primitive.
Expand Down
122 changes: 120 additions & 2 deletions frontend/test/lit/test_peephole_optimizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@
from lit_util_printers import print_jaxpr, print_mlir

from catalyst import qjit
from catalyst.api_extensions.error_mitigation import polynomial_extrapolation
from catalyst.debug import get_compilation_stage
from catalyst.passes import cancel_inverses
from catalyst.passes import cancel_inverses, pipeline


def flush_peephole_opted_mlir_to_iostream(QJIT):
Expand Down Expand Up @@ -75,6 +76,123 @@
test_transform_named_sequence_injection()


#
# pipeline
#


def test_pipeline_lowering():

Check notice on line 84 in frontend/test/lit/test_peephole_optimizations.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/test/lit/test_peephole_optimizations.py#L84

Missing function or method docstring (missing-function-docstring)
my_pipeline = {
"cancel_inverses": {},
"mitigate_with_zne": {
paul0403 marked this conversation as resolved.
Show resolved Hide resolved
"scale_factors": [1, 3, 5, 7],
"extrapolate": polynomial_extrapolation(2),
"folding": "global",
},
}

@qjit(keep_intermediate=True)
@pipeline(pass_pipeline=my_pipeline)
@qml.qnode(qml.device("lightning.qubit", wires=2))
def test_pipeline_lowering_workflow(x):
qml.RX(x, wires=[0])
qml.Hadamard(wires=[1])
qml.Hadamard(wires=[1])
return qml.expval(qml.PauliY(wires=0))

# CHECK: transform_named_sequence
# CHECK: _:AbstractTransformMod() = apply_registered_pass[
# CHECK: pass_name=lower-mitigation
# CHECK: ]
# CHECK: _:AbstractTransformMod() = apply_registered_pass[
# CHECK: options=func-name=test_pipeline_lowering_workflow_transformed
# CHECK: pass_name=remove-chained-self-inverse
# CHECK: ]
print_jaxpr(test_pipeline_lowering_workflow, 1.2)

# CHECK: transform.named_sequence @__transform_main
# CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=test_pipeline_lowering_workflow_transformed"}
# CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "lower-mitigation" to {{%.+}}
# CHECK-NEXT: transform.yield
print_mlir(test_pipeline_lowering_workflow, 1.2)

# CHECK: func.func private @test_pipeline_lowering_workflow_transformed.withMeasurements
# CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit
# CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit
# CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit
# CHECK: func.func public @jit_zne.test_pipeline_lowering_workflow_transformed
# CHECK: func.func private @test_pipeline_lowering_workflow_transformed
# CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit
# CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit
# CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit
test_pipeline_lowering_workflow(42.42)
flush_peephole_opted_mlir_to_iostream(test_pipeline_lowering_workflow)


test_pipeline_lowering()


def test_pipeline_lowering_keep_original():

Check notice on line 135 in frontend/test/lit/test_peephole_optimizations.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/test/lit/test_peephole_optimizations.py#L135

Missing function or method docstring (missing-function-docstring)
my_pipeline = {
"cancel_inverses": {},
"mitigate_with_zne": {
"scale_factors": [1, 3, 5, 7],
"extrapolate": polynomial_extrapolation(2),
"folding": "global",
},
}

@qml.qnode(qml.device("lightning.qubit", wires=2))
def f(x):
qml.RX(x, wires=[0])
qml.Hadamard(wires=[1])
qml.Hadamard(wires=[1])
return qml.expval(qml.PauliY(wires=0))

f_pipeline = pipeline(pass_pipeline=my_pipeline)(f)

@qjit(keep_intermediate=True)
def test_pipeline_lowering_keep_original_workflow(x):

Check notice on line 155 in frontend/test/lit/test_peephole_optimizations.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/test/lit/test_peephole_optimizations.py#L155

Unused argument 'x' (unused-argument)
return f(1.2), f_pipeline(1.2)

# CHECK: transform_named_sequence
# CHECK: _:AbstractTransformMod() = apply_registered_pass[
# CHECK: pass_name=lower-mitigation
# CHECK: ]
# CHECK: _:AbstractTransformMod() = apply_registered_pass[
# CHECK: options=func-name=f_transformed
# CHECK: pass_name=remove-chained-self-inverse
# CHECK: ]
print_jaxpr(test_pipeline_lowering_keep_original_workflow, 1.2)

# CHECK: transform.named_sequence @__transform_main
# CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=f_transformed"}
# CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "lower-mitigation" to {{%.+}}
# CHECK-NEXT: transform.yield
print_mlir(test_pipeline_lowering_keep_original_workflow, 1.2)

# CHECK: func.func private @f_transformed.withMeasurements
# CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit
# CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit
# CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit
# CHECK: func.func public @jit_test_pipeline_lowering_keep_original_workflow
# CHECK: {{%.+}} = call @f(
# CHECK: {{%.+}} = func.call @f_transformed.folded(
# CHECK: func.func private @f(
# CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit
# CHECK: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit
# CHECK: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit
# CHECK: func.func private @f_transformed
# CHECK: {{%.+}} = quantum.custom "RX"({{%.+}}) {{%.+}} : !quantum.bit
# CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit
# CHECK-NOT: {{%.+}} = quantum.custom "Hadamard"() {{%.+}} : !quantum.bit
test_pipeline_lowering_keep_original_workflow(42.42)
flush_peephole_opted_mlir_to_iostream(test_pipeline_lowering_keep_original_workflow)


test_pipeline_lowering_keep_original()


#
# cancel_inverses
#
Expand Down Expand Up @@ -132,8 +250,8 @@

# CHECK: module @test_cancel_inverses_tracing_and_lowering_workflow
# CHECK: transform.named_sequence @__transform_main
# CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=f_cancel_inverses"}
# CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=g_cancel_inverses"}
# CHECK-NEXT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=f_cancel_inverses"}
# CHECK-NOT: {{%.+}} = transform.apply_registered_pass "remove-chained-self-inverse" to {{%.+}} {options = "func-name=h_cancel_inverses"}
# CHECK-NEXT: transform.yield
print_mlir(test_cancel_inverses_tracing_and_lowering_workflow, 1.1)
Expand Down
11 changes: 11 additions & 0 deletions frontend/test/pytest/test_mitigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,17 @@ def mitigated_qnode(args):

assert np.allclose(mitigated_qnode(params), circuit(params))

@catalyst.qjit
def mitigated_qnode_passes(args):
return catalyst.passes.mitigate_with_zne(
circuit,
scale_factors=scale_factors,
extrapolate=extrapolation,
folding=folding,
)(args)

assert np.allclose(mitigated_qnode_passes(params), circuit(params))


@pytest.mark.parametrize("params", [0.1, 0.2, 0.3, 0.4, 0.5])
@pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate])
Expand Down
Loading
Loading