Skip to content

Commit 1690f8c

Browse files
Remove "experimental_capture" kwarg from "qjit" (#1657)
**Context:** When dealing with transforms, the `experimental_capture` keyword from `qjit` fails at capturing them because they are defined before the capture functionality has been activated. To circumvent this, program capture has to be enabled manually beforehand, which defies the purpose of the mentioned keyword. We propose using only the PL program capture enabling/disabling mechanism across the whole ecosystem to prevent such cases to happen. **Description of the Change:** Removed the `experimental_capture` keyword from its `qjit` function in favor of a unified program capture behavior. [(#1657)](#1657) Program capture has to be enabled before the definition of the function to be qjitted. For AOT compilation, program capture can be disabled right after the qjit usage and before execution. ```python import pennylane as qml from catalyst import qjit dev = qml.device("lightning.qubit", wires=1) qml.capture.enable() @qjit() @qml.qnode(dev) def circuit(x: float): qml.Hadamard(0) qml.CNOT([0, 1]) return qml.expval(qml.Z(0)) qml.capture.disable() circuit(0.1) ``` But for JIT compilation, program capture cannot be disabled before execution, otherwise the capture will not take place: ```python import pennylane as qml from catalyst import qjit dev = qml.device("lightning.qubit", wires=1) qml.capture.enable() @qjit() @qml.qnode(dev) def circuit(x): qml.Hadamard(0) qml.CNOT([0, 1]) return qml.expval(qml.Z(0)) circuit(0.1) qml.capture.disable() ``` **Benefits:** - Unified program capture behavior. - No longer necessary to use a wrapper function around a circuit with transforms: Before: ```python @qjit(experimental_capture=True) def wrapper(): @qml.transforms.merge_amplitude_embedding @qml.qnode(qml.device(backend, wires=2)) def captured_circuit(): qml.AmplitudeEmbedding(jnp.array([0.0, 1.0]), wires=0) qml.AmplitudeEmbedding(jnp.array([0.0, 1.0]), wires=1) return qml.expval(qml.PauliZ(0)) ``` After: ```python qml.capture.enable() @qjit() @qml.transforms.merge_amplitude_embedding @qml.qnode(qml.device(backend, wires=2)) def captured_circuit(): qml.AmplitudeEmbedding(jnp.array([0.0, 1.0]), wires=0) qml.AmplitudeEmbedding(jnp.array([0.0, 1.0]), wires=1) return qml.expval(qml.PauliZ(0)) ``` **Possible Drawbacks:** Users might find it difficult to understand where exactly program capture should be enabled/disabled. [sc-89121] --------- Co-authored-by: Isaac De Vlugt <34751083+isaacdevlugt@users.noreply.github.com>
1 parent dc51613 commit 1690f8c

File tree

5 files changed

+836
-276
lines changed

5 files changed

+836
-276
lines changed

doc/releases/changelog-dev.md

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,33 @@
1515

1616
<h3>Breaking changes 💔</h3>
1717

18+
* Catalyst has removed the `experimental_capture` keyword from the `qjit` decorator in favour of
19+
unified behaviour with PennyLane.
20+
[(#1657)](https://github.com/PennyLaneAI/catalyst/pull/1657)
21+
22+
Instead of enabling program capture with Catalyst via `qjit(experimental_capture=True)`, program capture
23+
can be enabled via the global toggle `qml.capture.enable()`:
24+
25+
```python
26+
import pennylane as qml
27+
from catalyst import qjit
28+
29+
dev = qml.device("lightning.qubit", wires=2)
30+
31+
qml.capture.enable()
32+
33+
@qjit
34+
@qml.qnode(dev)
35+
def circuit(x):
36+
qml.Hadamard(0)
37+
qml.CNOT([0, 1])
38+
return qml.expval(qml.Z(0))
39+
40+
circuit(0.1)
41+
```
42+
43+
Disabling program capture can be done with `qml.capture.disable()`.
44+
1845
<h3>Deprecations 👋</h3>
1946

2047
<h3>Bug fixes 🐛</h3>

frontend/catalyst/from_plxpr.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import pennylane as qml
2525
from jax.extend.linear_util import wrap_init
2626
from jax.interpreters.partial_eval import convert_constvars_jaxpr
27-
from pennylane.capture import PlxprInterpreter, disable, enable, enabled, qnode_prim
27+
from pennylane.capture import PlxprInterpreter, qnode_prim
2828
from pennylane.capture.expand_transforms import ExpandTransformsInterpreter
2929
from pennylane.capture.primitives import cond_prim as plxpr_cond_prim
3030
from pennylane.capture.primitives import for_loop_prim as plxpr_for_loop_prim
@@ -739,18 +739,9 @@ def trace_from_pennylane(fn, static_argnums, abstracted_axes, sig, kwargs):
739739
"abstracted_axes": abstracted_axes,
740740
}
741741

742-
if enabled():
743-
capture_on = True
744-
else:
745-
capture_on = False
746-
enable()
747-
748742
args = sig
749-
try:
750-
plxpr, out_type, out_treedef = make_jaxpr2(fn, **make_jaxpr_kwargs)(*args, **kwargs)
751-
jaxpr = from_plxpr(plxpr)(*args, **kwargs)
752-
finally:
753-
if not capture_on:
754-
disable()
743+
744+
plxpr, out_type, out_treedef = make_jaxpr2(fn, **make_jaxpr_kwargs)(*args, **kwargs)
745+
jaxpr = from_plxpr(plxpr)(*args, **kwargs)
755746

756747
return jaxpr, out_type, out_treedef, sig

frontend/catalyst/jit.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ def qjit(
9191
abstracted_axes=None,
9292
disable_assertions=False,
9393
seed=None,
94-
experimental_capture=False,
9594
circuit_transform_pipeline=None,
9695
pass_plugins=None,
9796
dialect_plugins=None,
@@ -151,9 +150,6 @@ def qjit(
151150
:func:`qml.sample() <pennylane.sample>`, :func:`qml.counts() <pennylane.counts>`,
152151
:func:`qml.probs() <pennylane.probs>`, :func:`qml.expval() <pennylane.expval>`,
153152
:func:`qml.var() <pennylane.var>`.
154-
experimental_capture (bool): If set to ``True``, the qjit decorator
155-
will use PennyLane's experimental program capture capabilities
156-
to capture the decorated function for compilation.
157153
circuit_transform_pipeline (Optional[dict[str, dict[str, str]]]):
158154
A dictionary that specifies the quantum circuit transformation pass pipeline order,
159155
and optionally arguments for each pass in the pipeline. Keys of this dictionary
@@ -715,7 +711,7 @@ def capture(self, args, **kwargs):
715711
dynamic_sig = get_abstract_signature(dynamic_args)
716712
full_sig = merge_static_args(dynamic_sig, args, static_argnums)
717713

718-
if self.compile_options.experimental_capture:
714+
if qml.capture.enabled():
719715
return trace_from_pennylane(
720716
self.user_function, static_argnums, abstracted_axes, full_sig, kwargs
721717
)

frontend/catalyst/pipelines.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,6 @@ class CompileOptions:
6868
disable_assertions (Optional[bool]): disables all assertions. Default is ``False``.
6969
seed (Optional[int]) : the seed for random operations in a qjit call.
7070
Default is None.
71-
experimental_capture (bool): If set to ``True``,
72-
use PennyLane's experimental program capture capabilities
73-
to capture the function for compilation.
7471
circuit_transform_pipeline (Optional[dict[str, dict[str, str]]]):
7572
A dictionary that specifies the quantum circuit transformation pass pipeline order,
7673
and optionally arguments for each pass in the pipeline.
@@ -94,7 +91,6 @@ class CompileOptions:
9491
checkpoint_stage: Optional[str] = ""
9592
disable_assertions: Optional[bool] = False
9693
seed: Optional[int] = None
97-
experimental_capture: Optional[bool] = False
9894
circuit_transform_pipeline: Optional[dict[str, dict[str, str]]] = None
9995
pass_plugins: Optional[Set[Path]] = None
10096
dialect_plugins: Optional[Set[Path]] = None

0 commit comments

Comments
 (0)