diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 6b386172d53..a3842494416 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -286,6 +286,9 @@

Bug fixes 🐛

+* `default.tensor` can now handle mid circuit measurements via the deferred measurement principle. + [(#6408)](https://github.com/PennyLaneAI/pennylane/pull/6408) + * The `validate_device_wires` transform now raises an error if abstract wires are provided. [(#6405)](https://github.com/PennyLaneAI/pennylane/pull/6405) diff --git a/pennylane/devices/default_tensor.py b/pennylane/devices/default_tensor.py index e80001e76f5..e93a2110028 100644 --- a/pennylane/devices/default_tensor.py +++ b/pennylane/devices/default_tensor.py @@ -414,6 +414,12 @@ def __init__( # that access it as soon as the device is created before running a circuit. self._quimb_circuit = self._initial_quimb_circuit(self.wires) + shots = kwargs.pop("shots", None) + if shots is not None: + raise qml.DeviceError( + "default.tensor only supports analytic simulations with shots=None." + ) + for arg in kwargs: if arg not in self._device_options: raise TypeError( @@ -571,7 +577,6 @@ def _setup_execution_config( Update the execution config with choices for how the device should be used and the device options. """ # TODO: add options for gradients next quarter - updated_values = {} new_device_options = dict(config.device_options) @@ -579,6 +584,11 @@ def _setup_execution_config( if option not in new_device_options: new_device_options[option] = getattr(self, f"_{option}", None) + if config.mcm_config.mcm_method not in {None, "deferred"}: + raise qml.DeviceError( + f"{self.name} only supports the deferred measurement principle, not {config.mcm_config.mcm_method}" + ) + return replace(config, **updated_values, device_options=new_device_options) def preprocess( @@ -610,6 +620,7 @@ def preprocess( program.add_transform(validate_measurements, name=self.name) program.add_transform(validate_observables, accepted_observables, name=self.name) program.add_transform(validate_device_wires, self._wires, name=self.name) + program.add_transform(qml.defer_measurements, device=self) program.add_transform( decompose, stopping_condition=stopping_condition, diff --git a/tests/devices/default_tensor/test_default_tensor.py b/tests/devices/default_tensor/test_default_tensor.py index 203680053d4..0da129bd60a 100644 --- a/tests/devices/default_tensor/test_default_tensor.py +++ b/tests/devices/default_tensor/test_default_tensor.py @@ -270,6 +270,19 @@ def test_kahypar_warning_not_raised(recwarn): assert len(recwarn) == 0 +def test_passing_shots_None(): + """Test that passing shots=None on initialization works without error.""" + dev = qml.device("default.tensor", shots=None) + assert dev.shots == qml.measurements.Shots(None) + + +def test_passing_finite_shots_error(): + """Test that an error is raised if finite shots are passed on initialization.""" + + with pytest.raises(qml.DeviceError, match=r"only supports analytic simulations"): + qml.device("default.tensor", shots=10) + + @pytest.mark.parametrize("method", ["mps", "tn"]) class TestSupportedGatesAndObservables: """Test that the DefaultTensor device supports all gates and observables that it claims to support.""" @@ -509,3 +522,45 @@ def circuit(): state = circuit() assert isinstance(state, TensorLike) assert len(state) == 2 ** (2 * num_orbitals + 1) + + +class TestMCMs: + """Test that default.tensor can handle mid circuit measurements.""" + + @pytest.mark.parametrize("mcm_method", ("one-shot", "tree-traversal")) + def test_error_on_unsupported_mcm_method(self, mcm_method): + """Test that an error is raised on unsupported mcm methods.""" + + mcm_config = qml.devices.MCMConfig(mcm_method=mcm_method) + config = qml.devices.ExecutionConfig(mcm_config=mcm_config) + with pytest.raises( + qml.DeviceError, match=r"only supports the deferred measurement principle." + ): + qml.device("default.tensor").preprocess(config) + + def test_simple_mcm_present(self): + """Test that the device can execute a circuit with a mid circuit measurement.""" + + dev = qml.device("default.tensor") + + @qml.qnode(dev) + def circuit(): + qml.measure(0) + return qml.expval(qml.Z(0)) + + res = circuit() + assert qml.math.allclose(res, 1) + + def test_mcm_conditional(self): + """Test that the device execute a circuit with an MCM and a conditional.""" + + dev = qml.device("default.tensor") + + @qml.qnode(dev) + def circuit(x): + m0 = qml.measure(0) + qml.cond(~m0, qml.RX)(x, 0) + return qml.expval(qml.Z(0)) + + res = circuit(0.5) + assert qml.math.allclose(res, np.cos(0.5))