Skip to content

Commit

Permalink
[BugFix] - PennyLane raises an informative error if running circuit(s…
Browse files Browse the repository at this point in the history
…) with dynamic tracers as wires (#6405)

**Context:** Pennylane provides incorrect results or non-informative
errors (depending on the circumstances) when working with `@jax.jit` and
dynamic wires (that is, JAX tracers) at the same time. For example:

```
dev = qml.device("default.qubit")

@jax.jit
@qml.qnode(dev)
def circuit(input_wires):
    qml.Hadamard(input_wires[1])
    return qml.probs(wires=[0, 1])

circuit([0, 1])

```

**Description of the Change:** Pennylane raises an informative error in
the `validate_device_wires` transform to inform the user that abstract
wires are not currently supported. We cannot raise this error in the
`Wires` class since this would be a problem for Catalyst. This solution
should not (hopefully) cause issues in other repositories.

**Benefits:** This prevents non-informative errors and, most
importantly, wrong results.

**Possible Drawbacks:** This change could potentially break existing
code(s) that somehow used abstract wires and for some reason they worked
fine.

**Related GitHub Issues:** #6380 

**Related ShortCut Stories:** [sc-75756]
  • Loading branch information
PietropaoloFrisoni authored Oct 17, 2024
1 parent 35a5996 commit 056bb92
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 1 deletion.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,9 @@

<h3>Bug fixes 🐛</h3>

* The `validate_device_wires` transform now raises an error if abstract wires are provided.
[(#6405)](https://github.com/PennyLaneAI/pennylane/pull/6405)

* Fixes `qml.math.expand_matrix` for qutrit and arbitrary qudit operators.
[(#6398)](https://github.com/PennyLaneAI/pennylane/pull/6398/)

Expand Down
16 changes: 15 additions & 1 deletion pennylane/devices/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,23 @@ def validate_device_wires(
The unaltered input circuit. The output type is explained in :func:`qml.transform <pennylane.transform>`.
Raises:
WireError: if the tape has a wire not present in the provided wires.
WireError: if the tape has a wire not present in the provided wires, or if abstract wires are present.
"""

if any(qml.math.is_abstract(w) for w in tape.wires):
raise WireError(
f"Cannot run circuit(s) on {name} as abstract wires are present in the tape: {tape.wires}. "
f"Abstract wires are not yet supported."
)

if wires:

if any(qml.math.is_abstract(w) for w in wires):
raise WireError(
f"Cannot run circuit(s) on {name} as abstract wires are present in the device: {wires}. "
f"Abstract wires are not yet supported."
)

if extra_wires := set(tape.wires) - set(wires):
raise WireError(
f"Cannot run circuit(s) on {name} as they contain wires "
Expand Down
1 change: 1 addition & 0 deletions pennylane/wires.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def _process(wires):
# Note, this is not the same as `isinstance(wires, Iterable)` which would
# pass for 0-dim numpy arrays that cannot be iterated over.
tuple_of_wires = tuple(wires)

except TypeError:
# if not iterable, interpret as single wire label
try:
Expand Down
31 changes: 31 additions & 0 deletions tests/devices/test_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,37 @@ def test_fill_in_wires(self):
assert batch[0].operations == tape1.operations
assert batch[0].shots == tape1.shots

@pytest.mark.jax
def test_error_abstract_wires_tape(self):
"""Tests that an error is raised if abstract wires are present in the tape."""

import jax

def jit_wires_tape(wires):
tape_with_abstract_wires = QuantumScript([qml.CNOT(wires=qml.wires.Wires(wires))])
validate_device_wires(tape_with_abstract_wires, name="fictional_device")

with pytest.raises(
qml.wires.WireError,
match="on fictional_device as abstract wires are present in the tape",
):
jax.jit(jit_wires_tape)([0, 1])

@pytest.mark.jax
def test_error_abstract_wires_dev(self):
"""Tests that an error is raised if abstract wires are present in the device."""

import jax

def jit_wires_dev(wires):
validate_device_wires(QuantumScript([]), wires=wires, name="fictional_device")

with pytest.raises(
qml.wires.WireError,
match="on fictional_device as abstract wires are present in the device",
):
jax.jit(jit_wires_dev)([0, 1])


class TestDecomposeValidation:
"""Unit tests for helper functions in qml.devices.qubit.preprocess"""
Expand Down

0 comments on commit 056bb92

Please sign in to comment.