From 544b1de3361c6ac1f45b88a63fa2cda3768af0b3 Mon Sep 17 00:00:00 2001 From: David Wierichs Date: Fri, 31 May 2024 15:57:25 +0200 Subject: [PATCH] Fix adjoint validation with global measurements (#5761) **Context:** Validation of the adjoint method did not take device wires into account, leading to the linked issue. **Description of the Change:** Include `validate_device_wires` in `_supports_adjoint` in `DefaultQubit`. **Benefits:** **Possible Drawbacks:** **Related GitHub Issues:** fixes #5760 [sc-64278] --- doc/releases/changelog-dev.md | 3 +++ pennylane/devices/default_qubit.py | 5 +++-- tests/devices/default_qubit/test_default_qubit.py | 15 ++++++++++++--- 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 489fd740d69..a7d698c734d 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -190,6 +190,9 @@

Bug fixes 🐛

+* The validation of the adjoint method in `DefaultQubit` correctly handles device wires now. + [(#5761)](https://github.com/PennyLaneAI/pennylane/pull/5761) + * `QuantumPhaseEstimation.map_wires` on longer modifies the original operation instance. [(#5698)](https://github.com/PennyLaneAI/pennylane/pull/5698) diff --git a/pennylane/devices/default_qubit.py b/pennylane/devices/default_qubit.py index 613f7395646..35dc391a6a5 100644 --- a/pennylane/devices/default_qubit.py +++ b/pennylane/devices/default_qubit.py @@ -184,11 +184,12 @@ def adjoint_observables(obs: qml.operation.Operator) -> bool: return obs.has_matrix -def _supports_adjoint(circuit): +def _supports_adjoint(circuit, device_wires, device_name): if circuit is None: return True prog = TransformProgram() + prog.add_transform(validate_device_wires, device_wires, name=device_name) _add_adjoint_transforms(prog) try: @@ -474,7 +475,7 @@ def supports_derivatives( ) if execution_config.gradient_method in {"adjoint", "best"}: - return _supports_adjoint(circuit=circuit) + return _supports_adjoint(circuit, device_wires=self.wires, device_name=self.name) return False @debug_logger diff --git a/tests/devices/default_qubit/test_default_qubit.py b/tests/devices/default_qubit/test_default_qubit.py index e1fd668f504..4af0d167ebb 100644 --- a/tests/devices/default_qubit/test_default_qubit.py +++ b/tests/devices/default_qubit/test_default_qubit.py @@ -127,15 +127,24 @@ def test_supports_backprop(self): assert dev.supports_jvp(config) is True assert dev.supports_vjp(config) is True - def test_supports_adjoint(self): + @pytest.mark.parametrize( + "device_wires, measurement", + [ + (None, qml.expval(qml.PauliZ(0))), + (2, qml.expval(qml.PauliZ(0))), + (2, qml.probs()), + (2, qml.probs([0])), + ], + ) + def test_supports_adjoint(self, device_wires, measurement): """Test that DefaultQubit says that it supports adjoint differentiation.""" - dev = DefaultQubit() + dev = DefaultQubit(wires=device_wires) config = ExecutionConfig(gradient_method="adjoint", use_device_gradient=True) assert dev.supports_derivatives(config) is True assert dev.supports_jvp(config) is True assert dev.supports_vjp(config) is True - qs = qml.tape.QuantumScript([], [qml.expval(qml.PauliZ(0))]) + qs = qml.tape.QuantumScript([], [measurement]) assert dev.supports_derivatives(config, qs) is True assert dev.supports_jvp(config, qs) is True assert dev.supports_vjp(config, qs) is True