Skip to content

Commit

Permalink
Do not allow creating registers with bitsize 0 (#6298)
Browse files Browse the repository at this point in the history
* Do not allow creating registers with bitsize 0

* Fix mypy errors
  • Loading branch information
tanujkhattar authored Sep 25, 2023
1 parent 8e4e7d1 commit acbc624
Show file tree
Hide file tree
Showing 8 changed files with 25 additions and 11 deletions.
6 changes: 5 additions & 1 deletion cirq-ft/cirq_ft/algos/and_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,11 @@ def _decompose_via_tree(
def decompose_from_registers(
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid]
) -> cirq.OP_TREE:
control, ancilla, target = quregs['control'], quregs['ancilla'], quregs['target']
control, ancilla, target = (
quregs['control'],
quregs.get('ancilla', np.array([])),
quregs['target'],
)
if len(self.cv) == 2:
yield self._decompose_single_and(
self.cv[0], self.cv[1], control[0], control[1], *target
Expand Down
2 changes: 1 addition & 1 deletion cirq-ft/cirq_ft/algos/apply_gate_to_lth_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]:

@cached_property
def target_registers(self) -> Tuple[infra.Register, ...]:
total_iteration_size = np.product(
total_iteration_size = np.prod(
tuple(reg.iteration_length for reg in self.selection_registers)
)
return (infra.Register('target', int(total_iteration_size)),)
Expand Down
2 changes: 1 addition & 1 deletion cirq-ft/cirq_ft/algos/prepare_uniform_superposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def decompose_from_registers(
context: cirq.DecompositionContext,
**quregs: NDArray[cirq.Qid], # type:ignore[type-var]
) -> cirq.OP_TREE:
controls, target = quregs['controls'], quregs['target']
controls, target = quregs.get('controls', ()), quregs['target']
# Find K and L as per https://arxiv.org/abs/1805.03662 Fig 12.
n, k = self.n, 0
while n > 1 and n % 2 == 0:
Expand Down
6 changes: 4 additions & 2 deletions cirq-ft/cirq_ft/algos/qrom.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,9 @@ def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]:

@cached_property
def target_registers(self) -> Tuple[infra.Register, ...]:
return tuple(infra.Register(f'target{i}', l) for i, l in enumerate(self.target_bitsizes))
return tuple(
infra.Register(f'target{i}', l) for i, l in enumerate(self.target_bitsizes) if l
)

def __repr__(self) -> str:
data_repr = f"({','.join(cirq._compat.proper_repr(d) for d in self.data)})"
Expand All @@ -129,7 +131,7 @@ def _load_nth_data(
**target_regs: NDArray[cirq.Qid], # type: ignore[type-var]
) -> cirq.OP_TREE:
for i, d in enumerate(self.data):
target = target_regs[f'target{i}']
target = target_regs.get(f'target{i}', ())
for q, bit in zip(target, f'{int(d[selection_idx]):0{len(target)}b}'):
if int(bit):
yield gate(q)
Expand Down
2 changes: 1 addition & 1 deletion cirq-ft/cirq_ft/algos/selected_majorana_fermion.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def selection_registers(self) -> Tuple[infra.SelectionRegister, ...]:

@cached_property
def target_registers(self) -> Tuple[infra.Register, ...]:
total_iteration_size = np.product(
total_iteration_size = np.prod(
tuple(reg.iteration_length for reg in self.selection_registers)
)
return (infra.Register('target', int(total_iteration_size)),)
Expand Down
2 changes: 1 addition & 1 deletion cirq-ft/cirq_ft/algos/state_preparation.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def decompose_from_registers(
**quregs: NDArray[cirq.Qid], # type:ignore[type-var]
) -> cirq.OP_TREE:
selection, less_than_equal = quregs['selection'], quregs['less_than_equal']
sigma_mu, alt, keep = quregs['sigma_mu'], quregs['alt'], quregs['keep']
sigma_mu, alt, keep = quregs.get('sigma_mu', ()), quregs['alt'], quregs.get('keep', ())
N = self.selection_registers[0].iteration_length
yield prepare_uniform_superposition.PrepareUniformSuperposition(N).on(*selection)
yield cirq.H.on_each(*sigma_mu)
Expand Down
11 changes: 8 additions & 3 deletions cirq-ft/cirq_ft/infra/gate_with_registers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,16 @@ class Register:
"""

name: str
bitsize: int
bitsize: int = attr.field()
shape: Tuple[int, ...] = attr.field(
converter=lambda v: (v,) if isinstance(v, int) else tuple(v), default=()
)

@bitsize.validator
def bitsize_validator(self, attribute, value):
if value <= 0:
raise ValueError(f"Bitsize for {self=} must be a positive integer. Found {value}.")

def all_idxs(self) -> Iterable[Tuple[int, ...]]:
"""Iterate over all possible indices of a multidimensional register."""
yield from itertools.product(*[range(sh) for sh in self.shape])
Expand All @@ -46,7 +51,7 @@ def total_bits(self) -> int:
This is the product of each of the dimensions in `shape`.
"""
return self.bitsize * int(np.product(self.shape))
return self.bitsize * int(np.prod(self.shape))

def __repr__(self):
return f'cirq_ft.Register(name="{self.name}", bitsize={self.bitsize}, shape={self.shape})'
Expand Down Expand Up @@ -137,7 +142,7 @@ def __repr__(self):

@classmethod
def build(cls, **registers: int) -> 'Registers':
return cls(Register(name=k, bitsize=v) for k, v in registers.items())
return cls(Register(name=k, bitsize=v) for k, v in registers.items() if v > 0)

@overload
def __getitem__(self, key: int) -> Register:
Expand Down
5 changes: 4 additions & 1 deletion cirq-ft/cirq_ft/infra/gate_with_registers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ def test_register():
assert r.bitsize == 5
assert r.shape == (1, 2)

with pytest.raises(ValueError, match="must be a positive integer"):
_ = cirq_ft.Register("zero bitsize register", bitsize=0)


def test_registers():
r1 = cirq_ft.Register("r1", 5)
Expand Down Expand Up @@ -96,7 +99,7 @@ def test_selection_registers_indexing(n, N, m, M):
assert np.ravel_multi_index((x, y), (N, M)) == x * M + y
assert np.unravel_index(x * M + y, (N, M)) == (x, y)

assert np.product(tuple(reg.iteration_length for reg in regs)) == N * M
assert np.prod(tuple(reg.iteration_length for reg in regs)) == N * M


def test_selection_registers_consistent():
Expand Down

0 comments on commit acbc624

Please sign in to comment.