Skip to content

Commit

Permalink
Merge branch 'master' into fix_fermionic_op_jax
Browse files Browse the repository at this point in the history
  • Loading branch information
austingmhuang authored Oct 18, 2024
2 parents 01b5f30 + 8a46e32 commit d83cef7
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 1 deletion.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,9 @@

<h3>Bug fixes 🐛</h3>

* `default.tensor` can now handle mid circuit measurements via the deferred measurement principle.
[(#6408)](https://github.com/PennyLaneAI/pennylane/pull/6408)

* The `validate_device_wires` transform now raises an error if abstract wires are provided.
[(#6405)](https://github.com/PennyLaneAI/pennylane/pull/6405)

Expand Down
13 changes: 12 additions & 1 deletion pennylane/devices/default_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,12 @@ def __init__(
# that access it as soon as the device is created before running a circuit.
self._quimb_circuit = self._initial_quimb_circuit(self.wires)

shots = kwargs.pop("shots", None)
if shots is not None:
raise qml.DeviceError(
"default.tensor only supports analytic simulations with shots=None."
)

for arg in kwargs:
if arg not in self._device_options:
raise TypeError(
Expand Down Expand Up @@ -571,14 +577,18 @@ def _setup_execution_config(
Update the execution config with choices for how the device should be used and the device options.
"""
# TODO: add options for gradients next quarter

updated_values = {}

new_device_options = dict(config.device_options)
for option in self._device_options:
if option not in new_device_options:
new_device_options[option] = getattr(self, f"_{option}", None)

if config.mcm_config.mcm_method not in {None, "deferred"}:
raise qml.DeviceError(
f"{self.name} only supports the deferred measurement principle, not {config.mcm_config.mcm_method}"
)

return replace(config, **updated_values, device_options=new_device_options)

def preprocess(
Expand Down Expand Up @@ -610,6 +620,7 @@ def preprocess(
program.add_transform(validate_measurements, name=self.name)
program.add_transform(validate_observables, accepted_observables, name=self.name)
program.add_transform(validate_device_wires, self._wires, name=self.name)
program.add_transform(qml.defer_measurements, device=self)
program.add_transform(
decompose,
stopping_condition=stopping_condition,
Expand Down
55 changes: 55 additions & 0 deletions tests/devices/default_tensor/test_default_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,19 @@ def test_kahypar_warning_not_raised(recwarn):
assert len(recwarn) == 0


def test_passing_shots_None():
"""Test that passing shots=None on initialization works without error."""
dev = qml.device("default.tensor", shots=None)
assert dev.shots == qml.measurements.Shots(None)


def test_passing_finite_shots_error():
"""Test that an error is raised if finite shots are passed on initialization."""

with pytest.raises(qml.DeviceError, match=r"only supports analytic simulations"):
qml.device("default.tensor", shots=10)


@pytest.mark.parametrize("method", ["mps", "tn"])
class TestSupportedGatesAndObservables:
"""Test that the DefaultTensor device supports all gates and observables that it claims to support."""
Expand Down Expand Up @@ -509,3 +522,45 @@ def circuit():
state = circuit()
assert isinstance(state, TensorLike)
assert len(state) == 2 ** (2 * num_orbitals + 1)


class TestMCMs:
"""Test that default.tensor can handle mid circuit measurements."""

@pytest.mark.parametrize("mcm_method", ("one-shot", "tree-traversal"))
def test_error_on_unsupported_mcm_method(self, mcm_method):
"""Test that an error is raised on unsupported mcm methods."""

mcm_config = qml.devices.MCMConfig(mcm_method=mcm_method)
config = qml.devices.ExecutionConfig(mcm_config=mcm_config)
with pytest.raises(
qml.DeviceError, match=r"only supports the deferred measurement principle."
):
qml.device("default.tensor").preprocess(config)

def test_simple_mcm_present(self):
"""Test that the device can execute a circuit with a mid circuit measurement."""

dev = qml.device("default.tensor")

@qml.qnode(dev)
def circuit():
qml.measure(0)
return qml.expval(qml.Z(0))

res = circuit()
assert qml.math.allclose(res, 1)

def test_mcm_conditional(self):
"""Test that the device execute a circuit with an MCM and a conditional."""

dev = qml.device("default.tensor")

@qml.qnode(dev)
def circuit(x):
m0 = qml.measure(0)
qml.cond(~m0, qml.RX)(x, 0)
return qml.expval(qml.Z(0))

res = circuit(0.5)
assert qml.math.allclose(res, np.cos(0.5))

0 comments on commit d83cef7

Please sign in to comment.