diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index c107e74e97f..3690638820b 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -88,6 +88,11 @@ * Sets up the framework for the development of an `assert_equal` function for testing operator comparison. [(#5634)](https://github.com/PennyLaneAI/pennylane/pull/5634) +* `qml.sample` can now be used on Boolean values representing mid-circuit measurement results in + traced quantum functions. This feature is used with Catalyst to enable the pattern + `m = measure(0); qml.sample(m)`. + [(#5673)](https://github.com/PennyLaneAI/pennylane/pull/5673) + * PennyLane operators can now automatically be captured as instructions in JAXPR. See the experimental `capture` module for more information. [(#5511)](https://github.com/PennyLaneAI/pennylane/pull/5511) @@ -202,6 +207,7 @@ Ahmed Darwish, Isaac De Vlugt, Pietropaolo Frisoni, Emiliano Godinez, +David Ittah, Soran Jahangiri, Korbinian Kottmann, Christina Lee, diff --git a/pennylane/measurements/measurements.py b/pennylane/measurements/measurements.py index 167cb26872d..8e8ac4f34eb 100644 --- a/pennylane/measurements/measurements.py +++ b/pennylane/measurements/measurements.py @@ -23,6 +23,7 @@ from typing import Optional, Sequence, Tuple, Union import pennylane as qml +from pennylane.math.utils import is_abstract from pennylane.operation import DecompositionUndefinedError, EigvalsUndefinedError, Operator from pennylane.pytrees import register_pytree from pennylane.typing import TensorLike @@ -162,6 +163,9 @@ def __init__( # Cast sequence of measurement values to list self.mv = obs if getattr(obs, "name", None) == "MeasurementValue" else list(obs) self.obs = None + elif is_abstract(obs): # Catalyst program with qml.sample(m, wires=i) + self.mv = obs + self.obs = None else: self.obs = obs self.mv = None @@ -306,7 +310,7 @@ def wires(self): This is the union of all the Wires objects of the measurement. """ - if self.mv is not None: + if self.mv is not None and not is_abstract(self.mv): if isinstance(self.mv, list): return qml.wires.Wires.all_wires([m.wires for m in self.mv]) return self.mv.wires diff --git a/tests/measurements/test_sample.py b/tests/measurements/test_sample.py index 8965d746628..c54b78f620f 100644 --- a/tests/measurements/test_sample.py +++ b/tests/measurements/test_sample.py @@ -499,6 +499,24 @@ def circuit(x): ) +@pytest.mark.jax +def test_sample_with_boolean_tracer(): + """Test that qml.sample can be used with Catalyst measurement values (Boolean tracer).""" + import jax + + def fun(b): + mp = qml.sample(b) + + assert mp.obs is None + assert isinstance(mp.mv, jax.interpreters.partial_eval.DynamicJaxprTracer) + assert mp.mv.dtype == bool + assert mp.mv.shape == () + assert isinstance(mp.wires, qml.wires.Wires) + assert mp.wires == () + + jax.make_jaxpr(fun)(True) + + @pytest.mark.jax @pytest.mark.parametrize( "obs", diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 904fbad2964..0df35a567e2 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -716,3 +716,24 @@ def f(x): CompileError, match="Pennylane does not support the VJP function without QJIT." ): vjp(x, dy) + + +class TestCatalystSample: + """Test qml.sample with Catalyst.""" + + @pytest.mark.xfail(reason="requires simultaneous catalyst pr") + def test_sample_measure(self): + """Test that qml.sample can be used with catalyst.measure.""" + + dev = qml.device("lightning.qubit", wires=1, shots=1) + + @qml.qjit + @qml.qnode(dev) + def circuit(x): + qml.RY(x, wires=0) + m = catalyst.measure(0) + qml.PauliX(0) + return qml.sample(m) + + assert circuit(0.0) == 0 + assert circuit(jnp.pi) == 1