Skip to content

Commit

Permalink
refactored tests
Browse files Browse the repository at this point in the history
  • Loading branch information
NoureldinYosri committed Jun 2, 2023
1 parent cb194ee commit 5f3e52d
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 119 deletions.
32 changes: 14 additions & 18 deletions cirq-core/cirq/protocols/apply_unitary_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,7 @@
# limitations under the License.
"""A protocol for implementing high performance unitary left-multiplies."""
import warnings
from typing import (
Any,
cast,
Iterable,
Optional,
Sequence,
Tuple,
TYPE_CHECKING,
TypeVar,
Union,
Callable,
)
from typing import Any, cast, Iterable, Optional, Sequence, Tuple, TYPE_CHECKING, TypeVar, Union

import numpy as np
from typing_extensions import Protocol
Expand Down Expand Up @@ -235,6 +224,12 @@ def subspace_index(
qid_shape=self.target_tensor.shape,
)

@classmethod
def for_unitary(cls, qid_shapes: Tuple[int, ...]) -> 'ApplyUnitaryArgs':
state = qis.eye_tensor(qid_shapes, dtype=np.complex128)
buffer = np.empty_like(state)
return ApplyUnitaryArgs(state, buffer, range(len(qid_shapes)))


class SupportsConsistentApplyUnitary(Protocol):
"""An object that can be efficiently left-multiplied into tensors."""
Expand Down Expand Up @@ -285,6 +280,10 @@ def _apply_unitary_(
"""


def _strat_apply_unitary_from_unitary_(val: Any, args: ApplyUnitaryArgs) -> Optional[np.ndarray]:
return _strat_apply_unitary_from_unitary(val, args, matrix=None)


def apply_unitary(
unitary_value: Any,
args: ApplyUnitaryArgs,
Expand Down Expand Up @@ -357,14 +356,14 @@ def apply_unitary(
if len(args.axes) <= 4:
strats = [
_strat_apply_unitary_from_apply_unitary,
_strat_apply_unitary_from_unitary,
_strat_apply_unitary_from_unitary_,
_strat_apply_unitary_from_decompose,
]
else:
strats = [
_strat_apply_unitary_from_apply_unitary,
_strat_apply_unitary_from_decompose,
_strat_apply_unitary_from_unitary,
_strat_apply_unitary_from_unitary_,
]
if not allow_decompose:
strats.remove(_strat_apply_unitary_from_decompose)
Expand All @@ -374,7 +373,6 @@ def apply_unitary(
with warnings.catch_warnings():
warnings.filterwarnings(action="error", category=np.ComplexWarning)
for strat in strats:
strat = cast(Callable[[Any, ApplyUnitaryArgs], Optional[np.ndarray]], strat)
result = strat(unitary_value, args)
if result is None:
break
Expand Down Expand Up @@ -473,12 +471,10 @@ def _strat_apply_unitary_from_decompose(val: Any, args: ApplyUnitaryArgs) -> Opt
return apply_unitaries(operations, qubits, args, None)
ordered_qubits = ancilla + tuple(qubits)
all_qid_shapes = qid_shape_protocol.qid_shape(ordered_qubits)
state = qis.eye_tensor(all_qid_shapes, dtype=np.complex128)
buffer = np.empty_like(state)
result = apply_unitaries(
operations,
ordered_qubits,
ApplyUnitaryArgs(state, buffer, range(len(ordered_qubits))),
ApplyUnitaryArgs.for_unitary(qid_shape_protocol.qid_shape(ordered_qubits)),
None,
)
if result is None or result is NotImplemented:
Expand Down
25 changes: 3 additions & 22 deletions cirq-core/cirq/protocols/apply_unitary_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import cirq
from cirq.protocols.apply_unitary_protocol import _incorporate_result_into_target
from cirq import testing


def test_apply_unitary_presence_absence():
Expand Down Expand Up @@ -719,41 +720,21 @@ def test_cast_to_complex():
cirq.apply_unitary(y0, args)


class NotDecomposableGate(cirq.Gate):
def num_qubits(self):
return 1


class DecomposableGate(cirq.Gate):
def __init__(self, sub_gate: cirq.Gate, allocate_ancilla: bool) -> None:
super().__init__()
self._sub_gate = sub_gate
self._allocate_ancilla = allocate_ancilla

def num_qubits(self):
return 1

def _decompose_(self, qubits):
if self._allocate_ancilla:
yield cirq.Z(cirq.LineQubit(1))
yield self._sub_gate(qubits[0])


def test_strat_apply_unitary_from_decompose():
state = np.eye(2, dtype=np.complex128)
args = cirq.ApplyUnitaryArgs(
target_tensor=state, available_buffer=np.zeros_like(state), axes=(0,)
)
np.testing.assert_allclose(
cirq.apply_unitaries(
[DecomposableGate(cirq.X, False)(cirq.LineQubit(0))], [cirq.LineQubit(0)], args
[testing.DecomposableGate(cirq.X, False)(cirq.LineQubit(0))], [cirq.LineQubit(0)], args
),
[[0, 1], [1, 0]],
)

with pytest.raises(TypeError):
_ = cirq.apply_unitaries(
[DecomposableGate(NotDecomposableGate(), True)(cirq.LineQubit(0))],
[testing.DecomposableGate(testing.NotDecomposableGate(), True)(cirq.LineQubit(0))],
[cirq.LineQubit(0)],
args,
)
9 changes: 2 additions & 7 deletions cirq-core/cirq/protocols/unitary_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import numpy as np
from typing_extensions import Protocol

from cirq import qis
from cirq._doc import doc_private
from cirq.protocols import qid_shape_protocol
from cirq.protocols.apply_unitary_protocol import ApplyUnitaryArgs, apply_unitaries
Expand Down Expand Up @@ -162,9 +161,7 @@ def _strat_unitary_from_apply_unitary(val: Any) -> Optional[np.ndarray]:
return NotImplemented

# Apply unitary effect to an identity matrix.
state = qis.eye_tensor(val_qid_shape, dtype=np.complex128)
buffer = np.empty_like(state)
result = method(ApplyUnitaryArgs(state, buffer, range(len(val_qid_shape))))
result = method(ApplyUnitaryArgs.for_unitary(val_qid_shape))

if result is NotImplemented or result is None:
return result
Expand All @@ -187,10 +184,8 @@ def _strat_unitary_from_decompose(val: Any) -> Optional[np.ndarray]:
val_qid_shape = qid_shape_protocol.qid_shape(ancillas) + val_qid_shape

# Apply sub-operations' unitary effects to an identity matrix.
state = qis.eye_tensor(val_qid_shape, dtype=np.complex128)
buffer = np.empty_like(state)
result = apply_unitaries(
operations, ordered_qubits, ApplyUnitaryArgs(state, buffer, range(len(val_qid_shape))), None
operations, ordered_qubits, ApplyUnitaryArgs.for_unitary(val_qid_shape), None
)

# Package result.
Expand Down
81 changes: 9 additions & 72 deletions cirq-core/cirq/protocols/unitary_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, cast
import functools
from typing import cast, Optional

import numpy as np
import pytest

import cirq
from cirq import testing

m0: np.ndarray = np.array([])
# yapf: disable
Expand Down Expand Up @@ -94,69 +94,6 @@ def _decompose_(self, qubits):
yield FullyImplemented(self.unitary_value).on(qubits[0])


class GateThatAllocatesAQubit(cirq.Gate):
def __init__(self, theta: float) -> None:
super().__init__()
self._theta = theta

def _num_qubits_(self):
return 1

def _decompose_(self, q):
anc = cirq.NamedQubit("anc")
yield cirq.CX(*q, anc)
yield (cirq.Z**self._theta)(anc)
yield cirq.CX(*q, anc)

def target_unitary(self) -> np.ndarray:
return np.array([[1, 0], [0, (-1 + 0j) ** self._theta]])


class GateThatAllocatesTwoQubits(cirq.Gate):
def _num_qubits_(self):
return 2

def _decompose_(self, qs):
q0, q1 = qs
anc = cirq.NamedQubit.range(2, prefix='two_ancillas_')

yield cirq.X(anc[0])
yield cirq.CX(q0, anc[0])
yield (cirq.Y)(anc[0])
yield cirq.CX(q0, anc[0])

yield cirq.CX(q1, anc[1])
yield (cirq.Z)(anc[1])
yield cirq.CX(q1, anc[1])

@classmethod
def target_unitary(cls) -> np.ndarray:
# Unitary = (-j I_2) \otimes Z
return np.array([[-1j, 0, 0, 0], [0, 1j, 0, 0], [0, 0, 1j, 0], [0, 0, 0, -1j]])


class GateThatDecomposesIntoNGates(cirq.Gate):
def __init__(self, n: int, sub_gate: cirq.Gate, theta: float) -> None:
super().__init__()
self._n = n
self._subgate = sub_gate
self._name = str(sub_gate)
self._theta = theta

def _num_qubits_(self) -> int:
return self._n

def _decompose_(self, qs):
ancilla = cirq.NamedQubit.range(self._n, prefix=self._name)
yield self._subgate.on_each(ancilla)
yield (cirq.Z**self._theta).on_each(qs)
yield self._subgate.on_each(ancilla)

def target_unitary(self) -> np.ndarray:
U = np.array([[1, 0], [0, (-1 + 0j) ** self._theta]])
return functools.reduce(np.kron, [U] * self._n)


class DecomposableOperation(cirq.Operation):
qubits = ()
with_qubits = NotImplemented
Expand Down Expand Up @@ -256,7 +193,7 @@ def test_has_unitary():
def test_decompose_gate_that_allocates_qubits(theta: float):
from cirq.protocols.unitary_protocol import _strat_unitary_from_decompose

gate = GateThatAllocatesAQubit(theta)
gate = testing.GateThatAllocatesAQubit(theta)
np.testing.assert_allclose(
cast(np.ndarray, _strat_unitary_from_decompose(gate)), gate.target_unitary()
)
Expand All @@ -270,8 +207,8 @@ def test_decompose_gate_that_allocates_qubits(theta: float):
def test_recusive_decomposition(n: int, theta: float):
from cirq.protocols.unitary_protocol import _strat_unitary_from_decompose

g1 = GateThatDecomposesIntoNGates(n, cirq.H, theta)
g2 = GateThatDecomposesIntoNGates(n, g1, theta)
g1 = testing.GateThatDecomposesIntoNGates(n, cirq.H, theta)
g2 = testing.GateThatDecomposesIntoNGates(n, g1, theta)
np.testing.assert_allclose(
cast(np.ndarray, _strat_unitary_from_decompose(g2)), g2.target_unitary()
)
Expand All @@ -291,12 +228,12 @@ def test_decompose_and_get_unitary():
np.testing.assert_allclose(_strat_unitary_from_decompose(OtherComposite()), m2)

np.testing.assert_allclose(
_strat_unitary_from_decompose(GateThatAllocatesTwoQubits()),
GateThatAllocatesTwoQubits.target_unitary(),
_strat_unitary_from_decompose(testing.GateThatAllocatesTwoQubits()),
testing.GateThatAllocatesTwoQubits.target_unitary(),
)
np.testing.assert_allclose(
_strat_unitary_from_decompose(GateThatAllocatesTwoQubits().on(a, b)),
GateThatAllocatesTwoQubits.target_unitary(),
_strat_unitary_from_decompose(testing.GateThatAllocatesTwoQubits().on(a, b)),
testing.GateThatAllocatesTwoQubits.target_unitary(),
)


Expand Down
8 changes: 8 additions & 0 deletions cirq-core/cirq/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,11 @@
)

from cirq.testing.sample_circuits import nonoptimal_toffoli_circuit

from cirq.testing.sample_gates import (
DecomposableGate,
NotDecomposableGate,
GateThatAllocatesAQubit,
GateThatAllocatesTwoQubits,
GateThatDecomposesIntoNGates,
)
Loading

0 comments on commit 5f3e52d

Please sign in to comment.