Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle confusion matrices in deferred measurements #5851

Merged
merged 27 commits into from
Oct 12, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Add confusion channel
  • Loading branch information
daxfohl committed Sep 2, 2022
commit 5820b18c45a06ae42adb40a0c8beda08dc117e67
64 changes: 57 additions & 7 deletions cirq-core/cirq/transformers/measurement_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
# limitations under the License.

import itertools
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union

import numpy as np

from cirq import ops, protocols, value
from cirq.transformers import transformer_api, transformer_primitives
Expand Down Expand Up @@ -96,17 +98,18 @@ def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
return op
gate = op.gate
if isinstance(gate, ops.MeasurementGate):
if gate.confusion_map:
raise NotImplementedError(
"Deferring confused measurement is not implemented, but found "
f"measurement with key={gate.key} and non-empty confusion map."
)
key = value.MeasurementKey.parse_serialized(gate.key)
targets = [_MeasurementQid(key, q) for q in op.qubits]
measurement_qubits[key] = targets
cxs = [ops.CX(q, target) for q, target in zip(op.qubits, targets)]
confusions = [
_ConfusionChannel(m, [op.qubits[i].dimension for i in indexes]).on(
*[targets[i] for i in indexes]
)
for indexes, m in gate.confusion_map.items()
]
xs = [ops.X(targets[i]) for i, b in enumerate(gate.full_invert_mask()) if b]
return cxs + xs
return cxs + confusions + xs
daxfohl marked this conversation as resolved.
Show resolved Hide resolved
elif protocols.is_measurement(op):
return [defer(op, None) for op in protocols.decompose_once(op)]
elif op.classical_controls:
Expand Down Expand Up @@ -227,3 +230,50 @@ def flip_inversion(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
return transformer_primitives.map_operations(
circuit, flip_inversion, deep=context.deep if context else True, tags_to_ignore=ignored
).unfreeze()


@value.value_equality
class _ConfusionChannel(ops.Gate):
"""The quantum equivalent of a confusion matrix.

For a classical confusion matrix, the quantum equivalent is a Kraus channel can be calculated
daxfohl marked this conversation as resolved.
Show resolved Hide resolved
by transposing the matrix, square-rooting each term, and forming a Kraus sequence of each term
daxfohl marked this conversation as resolved.
Show resolved Hide resolved
individually and the rest zeroed out. For example,

[[0.8, 0.2],
[0.1, 0.9]]

can be represented as

[[[sqrt(0.8), 0],
[0, 0]],
[[0, sqrt(0.1)],
[0, 0]],
[[0, 0],
[sqrt(0.2), 0]],
[[0, 0],
[0, sqrt(0.9)]]]"""
def __init__(self, confusion_map: np.ndarray, shape: Sequence[int]):
kraus = []
R, C = confusion_map.shape
for r in range(R):
for c in range(C):
v = confusion_map[r, c]
if v != 0:
m = np.zeros(confusion_map.shape)
m[c, r] = np.sqrt(v)
kraus.append(m)
self._shape = tuple(shape)
self._kraus = tuple(kraus)

def _qid_shape_(self) -> Tuple[int, ...]:
return self._shape

def _kraus_(self) -> Tuple[np.ndarray, ...]:
viathor marked this conversation as resolved.
Show resolved Hide resolved
return self._kraus

def _has_kraus_(self) -> bool:
return True

def _value_equality_values_(self):
return self._kraus, self._shape
18 changes: 13 additions & 5 deletions cirq-core/cirq/transformers/measurement_transformers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,13 +307,21 @@ def test_sympy_control():
def test_confusion_map():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
cirq.measure(q0, q1, key='a', confusion_map={(0,): np.array([[0.9, 0.1], [0.1, 0.9]])}),
cirq.H(q0),
cirq.measure(q0, key='a', confusion_map={(0,): np.array([[0.8, 0.2], [0.1, 0.9]])}),
cirq.X(q1).with_classical_controls('a'),
cirq.measure(q1, key='b'),
)
with pytest.raises(
NotImplementedError, match='Deferring confused measurement is not implemented'
):
_ = cirq.defer_measurements(circuit)
deferred = cirq.defer_measurements(circuit)

# We use DM simulator because the deferred circuit has channels
sim = cirq.DensityMatrixSimulator()

# 10K samples would take a long time if we had not deferred the measurements, as we'd have to
# run 10K simulations. Here with DM simulator it's 100ms.
result = sim.sample(deferred, repetitions=10_000)
assert 5_100 <= np.sum(result['a']) <= 5_900
assert np.all(result['a'] == result['b'])


def test_dephase():
Expand Down