Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create a frontend UI for users to specify quantum compilation pipelines #1131

Merged
merged 45 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
69c7933
pulling ZNE mitigation (`-lower-mitigation`) into the transform_named…
paul0403 Sep 11, 2024
6974981
Merge remote-tracking branch 'origin/main' into pass_pipeline_UI
paul0403 Sep 11, 2024
4761364
update passes.mitigate_with_zne test to include new `scale_factors`
paul0403 Sep 11, 2024
b5ec1b5
Merge remote-tracking branch 'origin/main' into pass_pipeline_UI
paul0403 Sep 11, 2024
13899c1
Merge remote-tracking branch 'origin/main' into pass_pipeline_UI
paul0403 Sep 12, 2024
2a38c02
init the local pipeline decorator
paul0403 Sep 12, 2024
986cc94
lit tests
paul0403 Sep 13, 2024
d7b98ae
add global peephole pipeline option in qjit
paul0403 Sep 13, 2024
00b28d7
format
paul0403 Sep 13, 2024
4efb372
Merge remote-tracking branch 'origin/main' into pass_pipeline_UI
paul0403 Sep 16, 2024
300c61c
reverting zne changes; this PR will leave zne untouched
paul0403 Sep 16, 2024
30b9f41
create merge rotation pass boilerplate
paul0403 Sep 16, 2024
99a3f13
reverting zne to main
paul0403 Sep 16, 2024
c23f52b
rewriting tests to exclude zne
paul0403 Sep 16, 2024
6f5dd59
put back lower-mitigation in default pipeline
paul0403 Sep 16, 2024
7dc40cd
codefactor
paul0403 Sep 16, 2024
1be95b5
local pipelines override global pipelines
paul0403 Sep 16, 2024
50f2585
make sure cudajit (which will never have the pass_pipeline) does not …
paul0403 Sep 16, 2024
d0d29c2
format
paul0403 Sep 16, 2024
60769e3
codefactor
paul0403 Sep 17, 2024
7648f88
Merge remote-tracking branch 'origin/main' into pass_pipeline_UI
paul0403 Sep 17, 2024
bc521c9
format
paul0403 Sep 17, 2024
8ed20c5
add support for pass options
paul0403 Sep 17, 2024
652d96a
format
paul0403 Sep 17, 2024
46f640f
Merge remote-tracking branch 'origin/main' into pass_pipeline_UI
paul0403 Sep 17, 2024
8d24767
type hint pipeline
paul0403 Sep 17, 2024
d410afb
no documentation for helpers in passes.py
paul0403 Sep 17, 2024
db3bd48
format
paul0403 Sep 17, 2024
551545d
add pytest for pass option effect
paul0403 Sep 17, 2024
6b20fc5
documentation
paul0403 Sep 17, 2024
6618e97
changelog
paul0403 Sep 17, 2024
72e6fc0
codefactor line too long in documentation
paul0403 Sep 17, 2024
a3b5425
Merge remote-tracking branch 'origin/main' into pass_pipeline_UI
paul0403 Sep 18, 2024
be02cad
typo
paul0403 Sep 18, 2024
9195851
add quantum scope TODO
paul0403 Sep 18, 2024
0850516
double ticks in documentation instead of single tick for code words
paul0403 Sep 18, 2024
2eb6f2f
Merge remote-tracking branch 'origin/main' into pass_pipeline_UI
paul0403 Sep 18, 2024
d2e5abb
Merge remote-tracking branch 'origin/main' into pass_pipeline_UI
paul0403 Sep 18, 2024
2ddf0af
rename variables in changelog
paul0403 Sep 18, 2024
8ae031f
change name uniquer to an import
paul0403 Sep 18, 2024
536d651
Merge remote-tracking branch 'origin/main' into pass_pipeline_UI
paul0403 Sep 19, 2024
f0a5bb6
Merge remote-tracking branch 'origin/main' into pass_pipeline_UI
paul0403 Sep 19, 2024
1511e17
remove "merge_rotations" from public documentation
paul0403 Sep 19, 2024
7c2142e
typo
paul0403 Sep 19, 2024
81b5006
add web link to `catalyst.passes` in documentation
paul0403 Sep 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion frontend/catalyst/api_extensions/error_mitigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
import pennylane as qml
from jax._src.tree_util import tree_flatten

from catalyst.jax_primitives import Folding, zne_p
from catalyst.jax_primitives import Folding, apply_registered_pass_p, zne_p
from catalyst.tracing.contexts import EvaluationContext


def _is_odd_positive(numbers_list):
Expand Down Expand Up @@ -164,6 +165,15 @@ def __init__(
):
if not isinstance(fn, qml.QNode):
raise TypeError(f"A QNode is expected, got the classical function {fn}")

wrapped_qnode_function = fn.func

def wrapper(*args, **kwrags):
if EvaluationContext.is_tracing():
apply_registered_pass_p.bind(pass_name="lower-mitigation")
return wrapped_qnode_function(*args, **kwrags)

fn.func = wrapper
self.fn = fn
self.__name__ = f"zne.{getattr(fn, '__name__', 'unknown')}"
self.num_folds = num_folds
Expand Down
2 changes: 1 addition & 1 deletion frontend/catalyst/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ class CompileOptions:
checkpoint_stage: Optional[str] = ""
disable_assertions: Optional[bool] = False
seed: Optional[int] = None
circuit_transform_pipeline: Optional[dict] = None

def __post_init__(self):
# Check that async runs must not be seeded
Expand Down Expand Up @@ -196,7 +197,6 @@ def run_writing_command(command: List[str], compile_options: Optional[CompileOpt
[
"apply-transform-sequence", # Run the transform sequence defined in the MLIR module
"annotate-function",
"lower-mitigation",
"lower-gradients",
"adjoint-lowering",
"disable-assertion",
Expand Down
15 changes: 14 additions & 1 deletion frontend/catalyst/jax_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,8 @@
"""

# If there already is a apply_registered_pass,
# insert after the last pass in the existing pass sequence.
# insert before the first pass in the existing pass sequence.
# See comment in frontend/catalyst/passes.py/pipeline()
# Note that ir.InsertionPoint(op) sets the insertion point to immediately BEFORE the op
named_sequence_op_block = named_sequence_op.regions[0].blocks[0]
first_op_in_block = named_sequence_op_block.operations[0].operation
Expand All @@ -528,16 +529,28 @@
"""

if first_op_in_block.name == "transform.apply_registered_pass":
"""

Check notice on line 532 in frontend/catalyst/jax_primitives.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jax_primitives.py#L532

String statement has no effect (pointless-string-statement)
paul0403 marked this conversation as resolved.
Show resolved Hide resolved
_ = len(named_sequence_op_block.operations)
yield_op = named_sequence_op_block.operations[_ - 1].operation
current_last_pass = named_sequence_op_block.operations[_ - 2].operation

with ir.InsertionPoint(yield_op):
apply_registered_pass_op = ApplyRegisteredPassOp(
result=transform_mod_type,
target=current_last_pass.result,
pass_name=pass_name,
options=options,
)
"""
current_first_pass = named_sequence_op_block.operations[0].operation
with ir.InsertionPoint(first_op_in_block):
apply_registered_pass_op = ApplyRegisteredPassOp(
result=transform_mod_type,
target=named_sequence_op.regions[0].blocks[0].arguments[0],
pass_name=pass_name,
options=options,
)
current_first_pass.operands[0] = apply_registered_pass_op.result

# otherwise it's the first pass, i.e. only a yield op is in the block
# so insert right before the yield op
Expand Down
9 changes: 8 additions & 1 deletion frontend/catalyst/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from catalyst.jax_tracer import lower_jaxpr_to_mlir, trace_to_jaxpr
from catalyst.logging import debug_logger, debug_logger_init
from catalyst.passes import _inject_transform_named_sequence
from catalyst.passes import pipeline as circuit_transform_pass_pipeline

Check notice on line 41 in frontend/catalyst/jit.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/jit.py#L41

Unused pipeline imported from catalyst.passes as circuit_transform_pass_pipeline (unused-import)
from catalyst.qfunc import QFunc
from catalyst.tracing.contexts import EvaluationContext
from catalyst.tracing.type_signatures import (
Expand Down Expand Up @@ -83,6 +84,7 @@
abstracted_axes=None,
disable_assertions=False,
seed=None,
circuit_transform_pipeline=None,
): # pylint: disable=too-many-arguments,unused-argument
"""A just-in-time decorator for PennyLane and JAX programs using Catalyst.

Expand Down Expand Up @@ -585,7 +587,12 @@
params = {}
params["static_argnums"] = kwargs.pop("static_argnums", static_argnums)
params["_out_tree_expected"] = []
return QFunc.__call__(qnode, *args, **dict(params, **kwargs))
return QFunc.__call__(
qnode,
self.compile_options.circuit_transform_pipeline,
*args,
**dict(params, **kwargs),
)

with Patcher(
(qml.QNode, "__call__", closure),
Expand Down
124 changes: 103 additions & 21 deletions frontend/catalyst/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,69 @@
"""

import copy
import functools

import pennylane as qml

from catalyst.api_extensions import (
mitigate_with_zne as _mitigate_with_zne_api_extensions,
)
from catalyst.jax_primitives import apply_registered_pass_p, transform_named_sequence_p
from catalyst.tracing.contexts import EvaluationContext


## API ##
# pylint: disable=line-too-long
def cancel_inverses(fn=None):
def pipeline(fn=None, *, pass_pipeline=None):
paul0403 marked this conversation as resolved.
Show resolved Hide resolved
"""
Here are documentation words
"""
paul0403 marked this conversation as resolved.
Show resolved Hide resolved

"""

Check notice on line 54 in frontend/catalyst/passes.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/passes.py#L54

String statement has no effect (pointless-string-statement)
Implementation design: it just stacks the decorators for the user automatically.
e.g.pass_pipeline = {
"mitigate_with_zne": {"scale_factors": [1, 2, 3]},
"cancel_inverses": {}
}
will just do
fn = mitigate_with_zne(scale_factors=[1, 2, 3])(fn)
fn = cancel_inverses(fn)

Note that since python 3.7, dictionary order is guaranteed to be in insertion order
https://docs.python.org/3/library/stdtypes.html#dict.values
"""
kwargs = copy.copy(locals())
kwargs.pop("fn")

if fn is None:
return functools.partial(pipeline, **kwargs)

if not isinstance(fn, qml.QNode):
raise TypeError(f"A QNode is expected, got the classical function {fn}")

if pass_pipeline is None:
# TODO: design a default peephole pipeline
return fn

API_calls = API_name_to_API_calls()

fn_original_name = fn.__name__
fn_clone = copy.copy(fn)
fn_clone.__name__ = fn_original_name + "_transformed"

# Note: we create wrappers to inject the apply_registered_pass primitive
# In other words, the last API call, aka the last wrapper, will be the outermost primitive
# Therefore when lowering these primitives to mlir (in jax_primitives.py), lower them in reverse order!
for decorator, decorator_args in pass_pipeline.items():
if decorator not in API_calls.keys():
raise RuntimeError(f"{decorator} is not a valid quantum transformation pass")

fn_clone = API_calls[decorator](fn_clone, keep_original=False, **decorator_args)

return fn_clone


def cancel_inverses(fn=None, keep_original=True):
"""
Specify that the ``-removed-chained-self-inverse`` MLIR compiler pass
for cancelling two neighbouring self-inverse
Expand Down Expand Up @@ -137,33 +191,61 @@
if not isinstance(fn, qml.QNode):
raise TypeError(f"A QNode is expected, got the classical function {fn}")

wrapped_qnode_function = fn.func
funcname = fn.__name__
wrapped_qnode_function = fn.func

def wrapper(*args, **kwrags):
# TODO: hint the compiler which qnodes to run the pass on via an func attribute,
# instead of the qnode name. That way the clone can have this attribute and
# the original can just not have it.
# We are not doing this right now and passing by name because this would
# be a discardable attribute (i.e. a user/developer wouldn't know that this
# attribute exists just by looking at qnode's documentation)
# But when we add the full peephole pipeline in the future, the attribute
# could get properly documented.

apply_registered_pass_p.bind(
pass_name="remove-chained-self-inverse",
options=f"func-name={funcname}" + "_cancel_inverses",
)
return wrapped_qnode_function(*args, **kwrags)
if keep_original:

fn_clone = copy.copy(fn)
fn_clone.func = wrapper
fn_clone.__name__ = funcname + "_cancel_inverses"
def wrapper(*args, **kwrags):
# TODO: hint the compiler which qnodes to run the pass on via an func attribute,
# instead of the qnode name. That way the clone can have this attribute and
# the original can just not have it.
# We are not doing this right now and passing by name because this would
# be a discardable attribute (i.e. a user/developer wouldn't know that this
# attribute exists just by looking at qnode's documentation)
# But when we add the full peephole pipeline in the future, the attribute
# could get properly documented.

return fn_clone
if EvaluationContext.is_tracing():
apply_registered_pass_p.bind(
pass_name="remove-chained-self-inverse",
options=f"func-name={funcname}" + "_cancel_inverses",
)
return wrapped_qnode_function(*args, **kwrags)

fn_clone = copy.copy(fn)
fn_clone.func = wrapper
fn_clone.__name__ = funcname + "_cancel_inverses"

return fn_clone

else:

def wrapper(*args, **kwrags):
if EvaluationContext.is_tracing():
apply_registered_pass_p.bind(

Check notice on line 226 in frontend/catalyst/passes.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/passes.py#L226

Missing function or method docstring (missing-function-docstring)
pass_name="remove-chained-self-inverse",
options=f"func-name={funcname}",
)
return wrapped_qnode_function(*args, **kwrags)

fn.func = wrapper

Check notice on line 232 in frontend/catalyst/passes.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/passes.py#L232

Unused argument 'keep_original' (unused-argument)
return fn


def mitigate_with_zne(*args, keep_original=False, **kwrags):
"""
An alias of catalyst.mitigate_with_zne.
See https://docs.pennylane.ai/projects/catalyst/en/stable/code/api/catalyst.mitigate_with_zne.html
"""
return _mitigate_with_zne_api_extensions(*args, **kwrags)


## IMPL and helpers ##
def API_name_to_API_calls():
return {"cancel_inverses": cancel_inverses, "mitigate_with_zne": mitigate_with_zne}


def _inject_transform_named_sequence():
"""
Inject a transform_named_sequence jax primitive.
Expand Down
5 changes: 4 additions & 1 deletion frontend/catalyst/qfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from catalyst.jax_primitives import func_p
from catalyst.jax_tracer import trace_quantum_function
from catalyst.logging import debug_logger
from catalyst.passes import pipeline
from catalyst.tracing.type_signatures import filter_static_args
from catalyst.utils.toml import DeviceCapabilities, ProgramFeatures

Expand Down Expand Up @@ -109,9 +110,11 @@

# pylint: disable=no-member
@debug_logger
def __call__(self, *args, **kwargs):
def __call__(self, pass_pipeline, *args, **kwargs):
assert isinstance(self, qml.QNode)

self = pipeline(pass_pipeline=pass_pipeline)(self)

Check notice on line 116 in frontend/catalyst/qfunc.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/qfunc.py#L116

Invalid assignment to self in method (self-cls-assignment)
paul0403 marked this conversation as resolved.
Show resolved Hide resolved

# Mid-circuit measurement configuration/execution
dynamic_one_shot_called = getattr(self, "_dynamic_one_shot_called", False)
if not dynamic_one_shot_called:
Expand Down
Loading
Loading