Skip to content

Commit

Permalink
Allow choice of subdimension in cirq.apply_unitary (#4910)
Browse files Browse the repository at this point in the history
Fixes #2862 by making the unitary subspace configurable.

@Strilanc @dabacon do we want to make this explicit or does it just create extra confusion, especially since it can't do arbitrary subdimensions that can't be represented as a slice?
  • Loading branch information
daxfohl authored Jun 7, 2022
1 parent cb0f9f5 commit abdd763
Show file tree
Hide file tree
Showing 2 changed files with 314 additions and 10 deletions.
72 changes: 62 additions & 10 deletions cirq-core/cirq/protocols/apply_unitary_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,20 @@ class ApplyUnitaryArgs:
dtype as the target tensor.
axes: Which axes the unitary effect is being applied to (e.g. the
qubits that the gate is operating on).
subspaces: Which subspace (in the computational basis) the unitary
effect is being applied to, on each axis. By default it applies
to subspace 0..d-1 on each axis, where d is the dimension of the
unitary effect on that axis. Subspaces on each axis must be
representable as a slice, so the dimensions specified here need to
have a consistent step size.
"""

def __init__(
self, target_tensor: np.ndarray, available_buffer: np.ndarray, axes: Iterable[int]
self,
target_tensor: np.ndarray,
available_buffer: np.ndarray,
axes: Iterable[int],
subspaces: Optional[Sequence[Tuple[int, ...]]] = None,
):
"""Inits ApplyUnitaryArgs.
Expand All @@ -75,11 +85,27 @@ def __init__(
dtype as the target tensor.
axes: Which axes the unitary effect is being applied to (e.g. the
qubits that the gate is operating on).
subspaces: Which subspace (in the computational basis) the unitary
effect is being applied to, on each axis. By default it applies
to subspace 0..d-1 on each axis, where d is the dimension of
the unitary effect on that axis. Subspaces on each axis must be
representable as a slice, so the dimensions specified here need
to have a consistent step size.
Raises:
ValueError: If the subspace count does not equal the axis count, if
any subspace has zero dimensions, or if any subspace has
dimensions specified without a consistent step size.
"""
self.target_tensor = target_tensor
self.available_buffer = available_buffer
self.axes = tuple(axes)
if subspaces is not None:
if len(self.axes) != len(subspaces):
raise ValueError('Subspace count does not match axis count.')
for subspace, axis in zip(subspaces, self.axes):
if any(s >= target_tensor.shape[axis] for s in subspace):
raise ValueError('Subspace specified does not exist in axis.')
self.slices = None if subspaces is None else tuple(map(_to_slice, subspaces))

@staticmethod
def default(
Expand Down Expand Up @@ -125,7 +151,7 @@ def with_axes_transposed_to_start(self) -> 'ApplyUnitaryArgs':
return ApplyUnitaryArgs(target_tensor, available_buffer, range(len(self.axes)))

def _for_operation_with_qid_shape(
self, indices: Iterable[int], qid_shape: Tuple[int, ...]
self, indices: Iterable[int], slices: Tuple[Union[int, slice], ...]
) -> 'ApplyUnitaryArgs':
"""Creates a sliced and transposed view of `self` appropriate for an
operation with shape `qid_shape` on qubits with the given indices.
Expand All @@ -138,14 +164,14 @@ def _for_operation_with_qid_shape(
Args:
indices: Integer indices into `self.axes` specifying which qubits
the operation applies to.
qid_shape: The qid shape of the operation, the expected number of
quantum levels in each qubit the operation applies to.
slices: The slices of the operation, the subdimension in each qubit
the operation applies to.
Returns: A new `ApplyUnitaryArgs` where `sub_args.target_tensor` and
`sub_args.available_buffer` are sliced and transposed views of
`self.target_tensor` and `self.available_buffer` respectively.
"""
slices = [slice(0, size) for size in qid_shape]
slices = tuple(size if isinstance(size, slice) else slice(0, size) for size in slices)
sub_axes = [self.axes[i] for i in indices]
axis_set = set(sub_axes)
other_axes = [axis for axis in range(len(self.target_tensor.shape)) if axis not in axis_set]
Expand Down Expand Up @@ -369,8 +395,12 @@ def _strat_apply_unitary_from_apply_unitary(
func = getattr(unitary_value, '_apply_unitary_', None)
if func is None:
return NotImplemented
op_qid_shape = qid_shape_protocol.qid_shape(unitary_value, (2,) * len(args.axes))
sub_args = args._for_operation_with_qid_shape(range(len(op_qid_shape)), op_qid_shape)
if args.slices is None:
op_qid_shape = qid_shape_protocol.qid_shape(unitary_value, (2,) * len(args.axes))
slices = tuple(slice(0, size) for size in op_qid_shape)
else:
slices = args.slices
sub_args = args._for_operation_with_qid_shape(range(len(slices)), slices)
sub_result = func(sub_args)
if sub_result is NotImplemented or sub_result is None:
return sub_result
Expand All @@ -390,8 +420,15 @@ def _strat_apply_unitary_from_unitary(
if matrix is NotImplemented or matrix is None:
return matrix

val_qid_shape = qid_shape_protocol.qid_shape(unitary_value, default=(2,) * len(args.axes))
sub_args = args._for_operation_with_qid_shape(range(len(val_qid_shape)), val_qid_shape)
if args.slices is None:
val_qid_shape = qid_shape_protocol.qid_shape(unitary_value, default=(2,) * len(args.axes))
slices = tuple(slice(0, size) for size in val_qid_shape)
else:
slices = args.slices
val_qid_shape = tuple(
((s.step if s.stop is None else s.stop) - s.start) // (s.step or 1) for s in slices
)
sub_args = args._for_operation_with_qid_shape(range(len(slices)), slices)
matrix = matrix.astype(sub_args.target_tensor.dtype)
if len(val_qid_shape) == 1 and val_qid_shape[0] <= 2:
# Special case for single-qubit, 2x2 or 1x1 operations.
Expand Down Expand Up @@ -557,3 +594,18 @@ def _incorporate_result_into_target(
return args.available_buffer
sub_args.target_tensor[...] = sub_result
return args.target_tensor


def _to_slice(subspace_def: Tuple[int, ...]):
if len(subspace_def) < 1:
raise ValueError(f'Subspace {subspace_def} has zero dimensions.')

if len(subspace_def) == 1:
return slice(subspace_def[0], subspace_def[0] + 1, 1)

step = subspace_def[1] - subspace_def[0]
for i in range(len(subspace_def) - 1):
if subspace_def[i + 1] - subspace_def[i] != step:
raise ValueError(f'Subspace {subspace_def} does not have consistent step size.')
stop = subspace_def[-1] + step
return slice(subspace_def[0], stop if stop >= 0 else None, step)
252 changes: 252 additions & 0 deletions cirq-core/cirq/protocols/apply_unitary_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,258 @@ def _unitary_(self):
)


# fmt: off
def test_subspace_size_2():
result = cirq.apply_unitary(
unitary_value=cirq.X,
args=cirq.ApplyUnitaryArgs(
target_tensor=cirq.eye_tensor((3,), dtype=np.complex64),
available_buffer=cirq.eye_tensor((3,), dtype=np.complex64),
axes=(0,),
subspaces=[(0, 1)],
),
)
np.testing.assert_allclose(
result,
np.array(
[
[0, 1, 0],
[1, 0, 0],
[0, 0, 1],
]
),
atol=1e-8,
)

result = cirq.apply_unitary(
unitary_value=cirq.X,
args=cirq.ApplyUnitaryArgs(
target_tensor=cirq.eye_tensor((3,), dtype=np.complex64),
available_buffer=cirq.eye_tensor((3,), dtype=np.complex64),
axes=(0,),
subspaces=[(0, 2)],
),
)
np.testing.assert_allclose(
result,
np.array(
[
[0, 0, 1],
[0, 1, 0],
[1, 0, 0],
]
),
atol=1e-8,
)

result = cirq.apply_unitary(
unitary_value=cirq.X,
args=cirq.ApplyUnitaryArgs(
target_tensor=cirq.eye_tensor((3,), dtype=np.complex64),
available_buffer=cirq.eye_tensor((3,), dtype=np.complex64),
axes=(0,),
subspaces=[(1, 2)],
),
)
np.testing.assert_allclose(
result,
np.array(
[
[1, 0, 0],
[0, 0, 1],
[0, 1, 0],
]
),
atol=1e-8,
)

result = cirq.apply_unitary(
unitary_value=cirq.X,
args=cirq.ApplyUnitaryArgs(
target_tensor=cirq.eye_tensor((4,), dtype=np.complex64),
available_buffer=cirq.eye_tensor((4,), dtype=np.complex64),
axes=(0,),
subspaces=[(1, 2)],
),
)
np.testing.assert_allclose(
result,
np.array(
[
[1, 0, 0, 0],
[0, 0, 1, 0],
[0, 1, 0, 0],
[0, 0, 0, 1],
]
),
atol=1e-8,
)


def test_subspaces_size_3():
plus_one_mod_3_gate = cirq.XPowGate(dimension=3)

result = cirq.apply_unitary(
unitary_value=plus_one_mod_3_gate,
args=cirq.ApplyUnitaryArgs(
target_tensor=cirq.eye_tensor((3,), dtype=np.complex64),
available_buffer=cirq.eye_tensor((3,), dtype=np.complex64),
axes=(0,),
subspaces=[(0, 1, 2)],
),
)
np.testing.assert_allclose(
result,
np.array(
[
[0, 0, 1],
[1, 0, 0],
[0, 1, 0],
]
),
atol=1e-8,
)

result = cirq.apply_unitary(
unitary_value=plus_one_mod_3_gate,
args=cirq.ApplyUnitaryArgs(
target_tensor=cirq.eye_tensor((3,), dtype=np.complex64),
available_buffer=cirq.eye_tensor((3,), dtype=np.complex64),
axes=(0,),
subspaces=[(2, 1, 0)],
),
)
np.testing.assert_allclose(
result,
np.array(
[
[0, 1, 0],
[0, 0, 1],
[1, 0, 0],
]
),
atol=1e-8,
)

result = cirq.apply_unitary(
unitary_value=plus_one_mod_3_gate,
args=cirq.ApplyUnitaryArgs(
target_tensor=cirq.eye_tensor((4,), dtype=np.complex64),
available_buffer=cirq.eye_tensor((4,), dtype=np.complex64),
axes=(0,),
subspaces=[(1, 2, 3)],
),
)
np.testing.assert_allclose(
result,
np.array(
[
[1, 0, 0, 0],
[0, 0, 0, 1],
[0, 1, 0, 0],
[0, 0, 1, 0],
]
),
atol=1e-8,
)


def test_subspaces_size_1():
phase_gate = cirq.MatrixGate(np.array([[1j]]))

result = cirq.apply_unitary(
unitary_value=phase_gate,
args=cirq.ApplyUnitaryArgs(
target_tensor=cirq.eye_tensor((2,), dtype=np.complex64),
available_buffer=cirq.eye_tensor((2,), dtype=np.complex64),
axes=(0,),
subspaces=[(0,)],
),
)
np.testing.assert_allclose(
result,
np.array(
[
[1j, 0],
[0, 1],
]
),
atol=1e-8,
)

result = cirq.apply_unitary(
unitary_value=phase_gate,
args=cirq.ApplyUnitaryArgs(
target_tensor=cirq.eye_tensor((2,), dtype=np.complex64),
available_buffer=cirq.eye_tensor((2,), dtype=np.complex64),
axes=(0,),
subspaces=[(1,)],
),
)
np.testing.assert_allclose(
result,
np.array(
[
[1, 0],
[0, 1j],
]
),
atol=1e-8,
)

result = cirq.apply_unitary(
unitary_value=phase_gate,
args=cirq.ApplyUnitaryArgs(
target_tensor=np.array([[0, 1], [1, 0]], dtype=np.complex64),
available_buffer=np.zeros((2, 2), dtype=np.complex64),
axes=(0,),
subspaces=[(1,)],
),
)
np.testing.assert_allclose(
result,
np.array(
[
[0, 1],
[1j, 0],
]
),
atol=1e-8,
)
# fmt: on


def test_invalid_subspaces():
with pytest.raises(ValueError, match='Subspace specified does not exist in axis'):
_ = cirq.ApplyUnitaryArgs(
target_tensor=cirq.eye_tensor((2,), dtype=np.complex64),
available_buffer=cirq.eye_tensor((2,), dtype=np.complex64),
axes=(0,),
subspaces=[(1, 2)],
)
with pytest.raises(ValueError, match='Subspace count does not match axis count'):
_ = cirq.ApplyUnitaryArgs(
target_tensor=cirq.eye_tensor((2,), dtype=np.complex64),
available_buffer=cirq.eye_tensor((2,), dtype=np.complex64),
axes=(0,),
subspaces=[(0, 1), (0, 1)],
)
with pytest.raises(ValueError, match='has zero dimensions'):
_ = cirq.ApplyUnitaryArgs(
target_tensor=cirq.eye_tensor((2,), dtype=np.complex64),
available_buffer=cirq.eye_tensor((2,), dtype=np.complex64),
axes=(0,),
subspaces=[()],
)
with pytest.raises(ValueError, match='does not have consistent step size'):
_ = cirq.ApplyUnitaryArgs(
target_tensor=cirq.eye_tensor((3,), dtype=np.complex64),
available_buffer=cirq.eye_tensor((3,), dtype=np.complex64),
axes=(0,),
subspaces=[(0, 2, 1)],
)


def test_incorporate_result_not_view():
tensor = np.zeros((2, 2))
tensor2 = np.zeros((2, 2))
Expand Down

0 comments on commit abdd763

Please sign in to comment.