Skip to content

Commit

Permalink
Improve support for Torch and Jax with dynamic_one_shot (#5672)
Browse files Browse the repository at this point in the history
**Context:**
Opened in favour of #5630. Bug fix for #5442. This PR updates
`dynamic_one_shot` so that it has better compatibility with the `torch`
and `jax` interfaces.

**Description of the Change:**
* Change casting method from `array.astype()` to `qml.math.cast` in the
`apply_operation` dispatch for `MidMeasureMP`.
* Update usage of `qml.math` in `dynamic_one_shot`.
* When using `qml.counts`, cast results to ints before converting to
strings for lists of MCM values and floats for single MCM values. This
is needed because jax arrays are not hashable, and the hash of torch
tensors seems to be independent of the value(s) stored inside it. Thus,
neither can be used as keys for dictionaries.

**Benefits:**
Better interface support with `dynamic_one_shot`.

**Possible Drawbacks:**

**Related GitHub Issues:**

---------

Co-authored-by: Vincent Michaud-Rioux <vincent.michaud-rioux@xanadu.ai>
Co-authored-by: Vincent Michaud-Rioux <vincentm@nanoacademic.com>
Co-authored-by: Christina Lee <christina@xanadu.ai>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Astral Cai <astral.cai@xanadu.ai>
Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>
Co-authored-by: lillian542 <38584660+lillian542@users.noreply.github.com>
Co-authored-by: Pietropaolo Frisoni <pietropaolo.frisoni@xanadu.ai>
Co-authored-by: Korbinian Kottmann <43949391+Qottmann@users.noreply.github.com>
Co-authored-by: Jay Soni <jbsoni@uwaterloo.ca>
Co-authored-by: Guillermo Alonso-Linaje <65235481+KetpuntoG@users.noreply.github.com>
Co-authored-by: Thomas R. Bromley <49409390+trbromley@users.noreply.github.com>
Co-authored-by: Isaac De Vlugt <34751083+isaacdevlugt@users.noreply.github.com>
Co-authored-by: Diksha Dhawan <40900030+ddhawan11@users.noreply.github.com>
Co-authored-by: Isaac De Vlugt <isaacdevlugt@gmail.com>
Co-authored-by: Diego <67476785+DSGuala@users.noreply.github.com>
Co-authored-by: trbromley <brotho02@gmail.com>
Co-authored-by: erick-xanadu <110487834+erick-xanadu@users.noreply.github.com>
Co-authored-by: David Ittah <dime10@users.noreply.github.com>
Co-authored-by: soranjh <40344468+soranjh@users.noreply.github.com>
  • Loading branch information
21 people authored May 24, 2024
1 parent fbc2a39 commit 59a1e05
Show file tree
Hide file tree
Showing 7 changed files with 244 additions and 17 deletions.
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@

<h3>Bug fixes 🐛</h3>

* The `dynamic_one_shot` transform now has expanded support for the `jax` and `torch` interfaces.
[(#5672)](https://github.com/PennyLaneAI/pennylane/pull/5672)

* The decomposition of `StronglyEntanglingLayers` is now compatible with broadcasting.
[(#5716)](https://github.com/PennyLaneAI/pennylane/pull/5716)

Expand Down Expand Up @@ -213,5 +216,6 @@ Korbinian Kottmann,
Christina Lee,
Vincent Michaud-Rioux,
Lee James O'Riordan,
Mudit Pandey,
Kenya Sakka,
David Wierichs.
6 changes: 4 additions & 2 deletions pennylane/devices/qubit/apply_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,10 @@ def binomial_fn(n, p):
# to reset enables jax.jit and prevents it from using Python callbacks
element = op.reset and sample == 1
matrix = qml.math.array(
[[(element + 1) % 2, (element) % 2], [(element) % 2, (element + 1) % 2]], like=interface
).astype(float)
[[(element + 1) % 2, (element) % 2], [(element) % 2, (element + 1) % 2]],
like=interface,
dtype=float,
)
state = apply_operation(
qml.QubitUnitary(matrix, wire), state, is_state_batched=is_state_batched, debugger=debugger
)
Expand Down
2 changes: 1 addition & 1 deletion pennylane/devices/qubit/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def simulate(
trainable_params=circuit.trainable_params,
)
keys = jax_random_split(prng_key, num=circuit.shots.total_shots)
if qml.math.get_deep_interface(circuit.data) == "jax":
if qml.math.get_deep_interface(circuit.data) == "jax" and prng_key is not None:
# pylint: disable=import-outside-toplevel
import jax

Expand Down
1 change: 1 addition & 0 deletions pennylane/math/single_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ def _take_autograd(tensor, indices, axis=None):
ar.autoray._SUBMODULE_ALIASES["tensorflow", "isclose"] = "tensorflow.experimental.numpy"
ar.autoray._SUBMODULE_ALIASES["tensorflow", "atleast_1d"] = "tensorflow.experimental.numpy"
ar.autoray._SUBMODULE_ALIASES["tensorflow", "all"] = "tensorflow.experimental.numpy"
ar.autoray._SUBMODULE_ALIASES["tensorflow", "ravel"] = "tensorflow.experimental.numpy"
ar.autoray._SUBMODULE_ALIASES["tensorflow", "vstack"] = "tensorflow.experimental.numpy"

tf_fft_functions = [
Expand Down
28 changes: 19 additions & 9 deletions pennylane/transforms/dynamic_one_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def measurement_with_no_shots(measurement):
)

interface = qml.math.get_deep_interface(circuit.data)
interface = "numpy" if interface == "builtins" else interface

all_mcms = [op for op in aux_tapes[0].operations if is_mcm(op)]
n_mcms = len(all_mcms)
Expand All @@ -243,10 +244,13 @@ def measurement_with_no_shots(measurement):
mcm_samples = qml.math.array(
[[res] if single_measurement else res[-n_mcms::] for res in results], like=interface
)
has_postselect = qml.math.array([op.postselect is not None for op in all_mcms]).reshape((1, -1))
# Can't use boolean dtype array with tf, hence why conditionally setting items to 0 or 1
has_postselect = qml.math.array(
[[int(op.postselect is not None) for op in all_mcms]], like=interface
)
postselect = qml.math.array(
[0 if op.postselect is None else op.postselect for op in all_mcms]
).reshape((1, -1))
[[0 if op.postselect is None else op.postselect for op in all_mcms]], like=interface
)
is_valid = qml.math.all(mcm_samples * has_postselect == postselect, axis=1)
has_valid = qml.math.any(is_valid)
mid_meas = [op for op in circuit.operations if is_mcm(op)]
Expand All @@ -268,7 +272,12 @@ def measurement_with_no_shots(measurement):
meas = measurement_with_no_shots(m)
m_count += 1
else:
result = qml.math.array([res[m_count] for res in results], like=interface)
result = [res[m_count] for res in results]
if not isinstance(m, CountsMP):
# We don't need to cast to arrays when using qml.counts. qml.math.array is not viable
# as it assumes all elements of the input are of builtin python types and not belonging
# to any particular interface
result = qml.math.stack(result, like=interface)
meas = gather_non_mcm(m, result, is_valid)
m_count += 1
if isinstance(m, SampleMP):
Expand All @@ -292,7 +301,9 @@ def gather_non_mcm(circuit_measurement, measurement, is_valid):
if isinstance(circuit_measurement, CountsMP):
tmp = Counter()
for i, d in enumerate(measurement):
tmp.update(dict((k, v * is_valid[i]) for k, v in d.items()))
tmp.update(
dict((k if isinstance(k, str) else float(k), v * is_valid[i]) for k, v in d.items())
)
tmp = Counter({k: v for k, v in tmp.items() if v > 0})
return dict(sorted(tmp.items()))
if isinstance(circuit_measurement, ExpectationMP):
Expand Down Expand Up @@ -341,14 +352,13 @@ def gather_mcm(measurement, samples, is_valid):
counts = qml.math.array(counts, like=interface)
return counts / qml.math.sum(counts)
if isinstance(measurement, CountsMP):
mcm_samples = [{"".join(str(v) for v in tuple(s)): 1} for s in mcm_samples]
mcm_samples = [{"".join(str(int(v)) for v in tuple(s)): 1} for s in mcm_samples]
return gather_non_mcm(measurement, mcm_samples, is_valid)
mcm_samples = qml.math.ravel(qml.math.array(mv.concretize(samples), like=interface))
if isinstance(measurement, ProbabilityMP):
mcm_samples = qml.math.array(mv.concretize(samples), like=interface).ravel()
counts = [qml.math.sum((mcm_samples == v) * is_valid) for v in list(mv.branches.values())]
counts = qml.math.array(counts, like=interface)
return counts / qml.math.sum(counts)
mcm_samples = qml.math.array([mv.concretize(samples)], like=interface).ravel()
if isinstance(measurement, CountsMP):
mcm_samples = [{s: 1} for s in mcm_samples]
mcm_samples = [{float(s): 1} for s in mcm_samples]
return gather_non_mcm(measurement, mcm_samples, is_valid)
55 changes: 50 additions & 5 deletions tests/devices/default_qubit/test_default_qubit_native_mcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for default qubit preprocessing."""
from functools import partial, reduce
from functools import reduce
from typing import Iterable, Sequence

import numpy as np
Expand All @@ -24,7 +24,11 @@

pytestmark = pytest.mark.slow

get_device = partial(qml.device, name="default.qubit", seed=8237945)

def get_device(**kwargs):
kwargs.setdefault("shots", None)
kwargs.setdefault("seed", 8237945)
return qml.device("default.qubit", **kwargs)


def validate_counts(shots, results1, results2, batch_size=None):
Expand Down Expand Up @@ -88,7 +92,7 @@ def validate_samples(shots, results1, results2, batch_size=None):
assert results1.ndim == results2.ndim
if results2.ndim > 1:
assert results1.shape[1] == results2.shape[1]
np.allclose(np.sum(results1), np.sum(results2), atol=20, rtol=0.2)
np.allclose(qml.math.sum(results1), qml.math.sum(results2), atol=20, rtol=0.2)


def validate_expval(shots, results1, results2, batch_size=None):
Expand Down Expand Up @@ -611,7 +615,7 @@ def test_sample_with_prng_key(shots, postselect, reset):
# pylint: disable=import-outside-toplevel
from jax.random import PRNGKey

dev = qml.device("default.qubit", shots=shots, seed=PRNGKey(678))
dev = get_device(shots=shots, seed=PRNGKey(678))
param = [np.pi / 4, np.pi / 3]
obs = qml.PauliZ(0) @ qml.PauliZ(1)

Expand Down Expand Up @@ -659,7 +663,7 @@ def test_jax_jit(diff_method, postselect, reset):

shots = 10

dev = qml.device("default.qubit", shots=shots, seed=jax.random.PRNGKey(678))
dev = get_device(shots=shots, seed=jax.random.PRNGKey(678))
params = [np.pi / 2.5, np.pi / 3, -np.pi / 3.5]
obs = qml.PauliY(0)

Expand Down Expand Up @@ -750,3 +754,44 @@ def func(x):
results2 = func2(param)
for r1, r2 in zip(results1.keys(), results2.keys()):
assert r1 == r2


@pytest.mark.torch
@pytest.mark.parametrize("postselect", [None, 1])
@pytest.mark.parametrize("diff_method", [None, "best"])
@pytest.mark.parametrize("measure_f", [qml.probs, qml.sample, qml.expval, qml.var])
@pytest.mark.parametrize("meas_obj", [qml.PauliZ(1), [0, 1], "composite_mcm", "mcm_list"])
def test_torch_integration(postselect, diff_method, measure_f, meas_obj):
"""Test that native MCM circuits are executed correctly with Torch"""
if measure_f in (qml.var, qml.expval) and (
isinstance(meas_obj, list) or meas_obj == "mcm_list"
):
pytest.skip("Can't use wires/mcm lists with var or expval")

import torch

shots = 7000
dev = get_device(shots=shots, seed=123456789)
param = torch.tensor(np.pi / 3, dtype=torch.float64)

@qml.qnode(dev, diff_method=diff_method)
def func(x):
qml.RX(x, 0)
m0 = qml.measure(0)
qml.RX(0.5 * x, 1)
m1 = qml.measure(1, postselect=postselect)
qml.cond((m0 + m1) == 2, qml.RY)(2.0 * x, 0)
m2 = qml.measure(0)

mid_measure = 0.5 * m2 if meas_obj == "composite_mcm" else [m1, m2]
measurement_key = "wires" if isinstance(meas_obj, list) else "op"
measurement_value = mid_measure if isinstance(meas_obj, str) else meas_obj
return measure_f(**{measurement_key: measurement_value})

func1 = func
func2 = qml.defer_measurements(func)

results1 = func1(param)
results2 = func2(param)

validate_measurements(measure_f, shots, results1, results2)
165 changes: 165 additions & 0 deletions tests/transforms/test_dynamic_one_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,168 @@ def test_len_measurements_mcms(measure, aux_measure, n_meas):
assert len(aux_tape.measurements) == n_meas + n_mcms
assert isinstance(aux_tape.measurements[0], aux_measure)
assert all(isinstance(m, SampleMP) for m in aux_tape.measurements[1:])


def assert_results(res, shots, n_mcms):
"""Helper to check that expected raw results of executing the transformed tape are correct"""
assert len(res) == shots
# One for the non-MeasurementValue MP, and the rest of the mid-circuit measurements
assert all(len(r) == n_mcms + 1 for r in res)
# Not validating distribution of results as device sampling unit tests already validate
# that samples are generated correctly.


@pytest.mark.jax
@pytest.mark.parametrize("measure_f", (qml.expval, qml.probs, qml.sample, qml.var))
@pytest.mark.parametrize("shots", [20, [20, 21]])
@pytest.mark.parametrize("n_mcms", [1, 3])
def test_tape_results_jax(shots, n_mcms, measure_f):
"""Test that the simulation results of a tape are correct with jax parameters"""
import jax

dev = qml.device("default.qubit", wires=4, shots=shots, seed=jax.random.PRNGKey(123))
param = jax.numpy.array(np.pi / 2)

mv = qml.measure(0)
mp = mv.measurements[0]

tape = qml.tape.QuantumScript(
[qml.RX(param, 0), mp] + [MidMeasureMP(0, id=str(i)) for i in range(n_mcms - 1)],
[measure_f(op=qml.PauliZ(0)), measure_f(op=mv)],
shots=shots,
)

tapes, _ = qml.dynamic_one_shot(tape)
results = dev.execute(tapes)[0]

# The transformed tape never has a shot vector
if isinstance(shots, list):
shots = sum(shots)

assert_results(results, shots, n_mcms)


@pytest.mark.jax
@pytest.mark.parametrize(
"measure_f, expected1, expected2",
[
(qml.expval, 1.0, 1.0),
(qml.probs, [1, 0], [0, 1]),
(qml.sample, 1, 1),
(qml.var, 0.0, 0.0),
],
)
@pytest.mark.parametrize("shots", [20, [20, 21]])
@pytest.mark.parametrize("n_mcms", [1, 3])
def test_jax_results_processing(shots, n_mcms, measure_f, expected1, expected2):
"""Test that the results of tapes are processed correctly for tapes with jax parameters"""
import jax.numpy as jnp

mv = qml.measure(0)
mp = mv.measurements[0]

tape = qml.tape.QuantumScript(
[qml.RX(1.5, 0), mp] + [MidMeasureMP(0)] * (n_mcms - 1),
[measure_f(op=qml.PauliZ(0)), measure_f(op=mv)],
shots=shots,
)
_, fn = qml.dynamic_one_shot(tape)
all_shots = sum(shots) if isinstance(shots, list) else shots

first_res = jnp.array([1.0, 0.0]) if measure_f == qml.probs else jnp.array(1.0)
rest = jnp.array(1, dtype=int)
single_shot_res = (first_res,) + (rest,) * n_mcms
# Raw results for each shot are (sample_for_first_measurement,) + (sample for 1st MCM, sample for 2nd MCM, ...)
raw_results = (single_shot_res,) * all_shots
raw_results = (raw_results,)
res = fn(raw_results)

if measure_f is qml.sample:
# All samples 1
expected1 = (
[[expected1] * s for s in shots] if isinstance(shots, list) else [expected1] * shots
)
expected2 = (
[[expected2] * s for s in shots] if isinstance(shots, list) else [expected2] * shots
)
else:
expected1 = [expected1 for _ in shots] if isinstance(shots, list) else expected1
expected2 = [expected2 for _ in shots] if isinstance(shots, list) else expected2

if isinstance(shots, list):
assert len(res) == len(shots)
for r, e1, e2 in zip(res, expected1, expected2):
# Expected result is 2-list since we have two measurements in the tape
assert qml.math.allclose(r, [e1, e2])
else:
# Expected result is 2-list since we have two measurements in the tape
assert qml.math.allclose(res, [expected1, expected2])


@pytest.mark.jax
@pytest.mark.parametrize(
"measure_f, expected1, expected2",
[
(qml.expval, 1.0, 1.0),
(qml.probs, [1, 0], [0, 1]),
(qml.sample, 1, 1),
(qml.var, 0.0, 0.0),
],
)
@pytest.mark.parametrize("shots", [20, [20, 22]])
def test_jax_results_postselection_processing(shots, measure_f, expected1, expected2):
"""Test that the results of tapes are processed correctly for tapes with jax parameters
when postselecting"""
import jax.numpy as jnp

param = jnp.array(np.pi / 2)
fill_value = np.iinfo(np.int32).min
mv = qml.measure(0, postselect=1)
mp = mv.measurements[0]

tape = qml.tape.QuantumScript(
[qml.RX(param, 0), mp, MidMeasureMP(0)],
[measure_f(op=qml.PauliZ(0)), measure_f(op=mv)],
shots=shots,
)
_, fn = qml.dynamic_one_shot(tape)
all_shots = sum(shots) if isinstance(shots, list) else shots

# Alternating tuple. Only the values at odd indices are valid
first_res_two_shot = (
(jnp.array([1.0, 0.0]), jnp.array([0.0, 1.0]))
if measure_f == qml.probs
else (jnp.array(1.0), jnp.array(0.0))
)
first_res = first_res_two_shot * (all_shots // 2)
# Tuple of alternating 1s and 0s. Zero is invalid as postselecting on 1
postselect_res = (jnp.array(1, dtype=int), jnp.array(0, dtype=int)) * (all_shots // 2)
rest = (jnp.array(1, dtype=int),) * all_shots
# Raw results for each shot are (sample_for_first_measurement, sample for 1st MCM, sample for 2nd MCM)
raw_results = tuple(zip(first_res, postselect_res, rest))
raw_results = (raw_results,)
res = fn(raw_results)

if measure_f is qml.sample:
expected1 = (
[[expected1, fill_value] * (s // 2) for s in shots]
if isinstance(shots, list)
else [expected1, fill_value] * (shots // 2)
)
expected2 = (
[[expected2, fill_value] * (s // 2) for s in shots]
if isinstance(shots, list)
else [expected2, fill_value] * (shots // 2)
)
else:
expected1 = [expected1 for _ in shots] if isinstance(shots, list) else expected1
expected2 = [expected2 for _ in shots] if isinstance(shots, list) else expected2

if isinstance(shots, list):
assert len(res) == len(shots)
for r, e1, e2 in zip(res, expected1, expected2):
# Expected result is 2-list since we have two measurements in the tape
assert qml.math.allclose(r, [e1, e2])
else:
# Expected result is 2-list since we have two measurements in the tape
assert qml.math.allclose(res, [expected1, expected2])

0 comments on commit 59a1e05

Please sign in to comment.