Skip to content

Commit

Permalink
Fix adjoint validation with global measurements (#5761)
Browse files Browse the repository at this point in the history
**Context:**
Validation of the adjoint method did not take device wires into account,
leading to the linked issue.

**Description of the Change:**
Include `validate_device_wires` in `_supports_adjoint` in
`DefaultQubit`.

**Benefits:**

**Possible Drawbacks:**

**Related GitHub Issues:**
fixes #5760

[sc-64278]
  • Loading branch information
dwierichs authored May 31, 2024
1 parent f318ec4 commit 544b1de
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 5 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,9 @@

<h3>Bug fixes 🐛</h3>

* The validation of the adjoint method in `DefaultQubit` correctly handles device wires now.
[(#5761)](https://github.com/PennyLaneAI/pennylane/pull/5761)

* `QuantumPhaseEstimation.map_wires` on longer modifies the original operation instance.
[(#5698)](https://github.com/PennyLaneAI/pennylane/pull/5698)

Expand Down
5 changes: 3 additions & 2 deletions pennylane/devices/default_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,12 @@ def adjoint_observables(obs: qml.operation.Operator) -> bool:
return obs.has_matrix


def _supports_adjoint(circuit):
def _supports_adjoint(circuit, device_wires, device_name):
if circuit is None:
return True

prog = TransformProgram()
prog.add_transform(validate_device_wires, device_wires, name=device_name)
_add_adjoint_transforms(prog)

try:
Expand Down Expand Up @@ -474,7 +475,7 @@ def supports_derivatives(
)

if execution_config.gradient_method in {"adjoint", "best"}:
return _supports_adjoint(circuit=circuit)
return _supports_adjoint(circuit, device_wires=self.wires, device_name=self.name)
return False

@debug_logger
Expand Down
15 changes: 12 additions & 3 deletions tests/devices/default_qubit/test_default_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,24 @@ def test_supports_backprop(self):
assert dev.supports_jvp(config) is True
assert dev.supports_vjp(config) is True

def test_supports_adjoint(self):
@pytest.mark.parametrize(
"device_wires, measurement",
[
(None, qml.expval(qml.PauliZ(0))),
(2, qml.expval(qml.PauliZ(0))),
(2, qml.probs()),
(2, qml.probs([0])),
],
)
def test_supports_adjoint(self, device_wires, measurement):
"""Test that DefaultQubit says that it supports adjoint differentiation."""
dev = DefaultQubit()
dev = DefaultQubit(wires=device_wires)
config = ExecutionConfig(gradient_method="adjoint", use_device_gradient=True)
assert dev.supports_derivatives(config) is True
assert dev.supports_jvp(config) is True
assert dev.supports_vjp(config) is True

qs = qml.tape.QuantumScript([], [qml.expval(qml.PauliZ(0))])
qs = qml.tape.QuantumScript([], [measurement])
assert dev.supports_derivatives(config, qs) is True
assert dev.supports_jvp(config, qs) is True
assert dev.supports_vjp(config, qs) is True
Expand Down

0 comments on commit 544b1de

Please sign in to comment.