diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 6b386172d53..a3842494416 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -286,6 +286,9 @@
Bug fixes 🐛
+* `default.tensor` can now handle mid circuit measurements via the deferred measurement principle.
+ [(#6408)](https://github.com/PennyLaneAI/pennylane/pull/6408)
+
* The `validate_device_wires` transform now raises an error if abstract wires are provided.
[(#6405)](https://github.com/PennyLaneAI/pennylane/pull/6405)
diff --git a/pennylane/devices/default_tensor.py b/pennylane/devices/default_tensor.py
index e80001e76f5..e93a2110028 100644
--- a/pennylane/devices/default_tensor.py
+++ b/pennylane/devices/default_tensor.py
@@ -414,6 +414,12 @@ def __init__(
# that access it as soon as the device is created before running a circuit.
self._quimb_circuit = self._initial_quimb_circuit(self.wires)
+ shots = kwargs.pop("shots", None)
+ if shots is not None:
+ raise qml.DeviceError(
+ "default.tensor only supports analytic simulations with shots=None."
+ )
+
for arg in kwargs:
if arg not in self._device_options:
raise TypeError(
@@ -571,7 +577,6 @@ def _setup_execution_config(
Update the execution config with choices for how the device should be used and the device options.
"""
# TODO: add options for gradients next quarter
-
updated_values = {}
new_device_options = dict(config.device_options)
@@ -579,6 +584,11 @@ def _setup_execution_config(
if option not in new_device_options:
new_device_options[option] = getattr(self, f"_{option}", None)
+ if config.mcm_config.mcm_method not in {None, "deferred"}:
+ raise qml.DeviceError(
+ f"{self.name} only supports the deferred measurement principle, not {config.mcm_config.mcm_method}"
+ )
+
return replace(config, **updated_values, device_options=new_device_options)
def preprocess(
@@ -610,6 +620,7 @@ def preprocess(
program.add_transform(validate_measurements, name=self.name)
program.add_transform(validate_observables, accepted_observables, name=self.name)
program.add_transform(validate_device_wires, self._wires, name=self.name)
+ program.add_transform(qml.defer_measurements, device=self)
program.add_transform(
decompose,
stopping_condition=stopping_condition,
diff --git a/tests/devices/default_tensor/test_default_tensor.py b/tests/devices/default_tensor/test_default_tensor.py
index 203680053d4..0da129bd60a 100644
--- a/tests/devices/default_tensor/test_default_tensor.py
+++ b/tests/devices/default_tensor/test_default_tensor.py
@@ -270,6 +270,19 @@ def test_kahypar_warning_not_raised(recwarn):
assert len(recwarn) == 0
+def test_passing_shots_None():
+ """Test that passing shots=None on initialization works without error."""
+ dev = qml.device("default.tensor", shots=None)
+ assert dev.shots == qml.measurements.Shots(None)
+
+
+def test_passing_finite_shots_error():
+ """Test that an error is raised if finite shots are passed on initialization."""
+
+ with pytest.raises(qml.DeviceError, match=r"only supports analytic simulations"):
+ qml.device("default.tensor", shots=10)
+
+
@pytest.mark.parametrize("method", ["mps", "tn"])
class TestSupportedGatesAndObservables:
"""Test that the DefaultTensor device supports all gates and observables that it claims to support."""
@@ -509,3 +522,45 @@ def circuit():
state = circuit()
assert isinstance(state, TensorLike)
assert len(state) == 2 ** (2 * num_orbitals + 1)
+
+
+class TestMCMs:
+ """Test that default.tensor can handle mid circuit measurements."""
+
+ @pytest.mark.parametrize("mcm_method", ("one-shot", "tree-traversal"))
+ def test_error_on_unsupported_mcm_method(self, mcm_method):
+ """Test that an error is raised on unsupported mcm methods."""
+
+ mcm_config = qml.devices.MCMConfig(mcm_method=mcm_method)
+ config = qml.devices.ExecutionConfig(mcm_config=mcm_config)
+ with pytest.raises(
+ qml.DeviceError, match=r"only supports the deferred measurement principle."
+ ):
+ qml.device("default.tensor").preprocess(config)
+
+ def test_simple_mcm_present(self):
+ """Test that the device can execute a circuit with a mid circuit measurement."""
+
+ dev = qml.device("default.tensor")
+
+ @qml.qnode(dev)
+ def circuit():
+ qml.measure(0)
+ return qml.expval(qml.Z(0))
+
+ res = circuit()
+ assert qml.math.allclose(res, 1)
+
+ def test_mcm_conditional(self):
+ """Test that the device execute a circuit with an MCM and a conditional."""
+
+ dev = qml.device("default.tensor")
+
+ @qml.qnode(dev)
+ def circuit(x):
+ m0 = qml.measure(0)
+ qml.cond(~m0, qml.RX)(x, 0)
+ return qml.expval(qml.Z(0))
+
+ res = circuit(0.5)
+ assert qml.math.allclose(res, np.cos(0.5))