Skip to content

Commit

Permalink
Speed up hashing for GridQubit, LineQubit, and NamedQubit (quantumlib…
Browse files Browse the repository at this point in the history
  • Loading branch information
maffoo authored Nov 17, 2023
1 parent 065272c commit 4efa93d
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 99 deletions.
87 changes: 49 additions & 38 deletions cirq/devices/grid_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import numpy as np

from cirq import _compat, ops, protocols
from cirq import ops, protocols

if TYPE_CHECKING:
import cirq
Expand All @@ -29,9 +29,43 @@
class _BaseGridQid(ops.Qid):
"""The Base class for `GridQid` and `GridQubit`."""

def __init__(self, row: int, col: int):
self._row = row
self._col = col
_row: int
_col: int
_dimension: int
_hash: Optional[int] = None

def __getstate__(self):
# Don't save hash when pickling; see #3777.
state = self.__dict__
if "_hash" in state:
state = state.copy()
del state["_hash"]
return state

def __hash__(self) -> int:
if self._hash is None:
self._hash = hash((self._row, self._col, self._dimension))
return self._hash

def __eq__(self, other):
# Explicitly implemented for performance (vs delegating to Qid).
if isinstance(other, _BaseGridQid):
return (
self._row == other._row
and self._col == other._col
and self._dimension == other._dimension
)
return NotImplemented

def __ne__(self, other):
# Explicitly implemented for performance (vs delegating to Qid).
if isinstance(other, _BaseGridQid):
return (
self._row != other._row
or self._col != other._col
or self._dimension != other._dimension
)
return NotImplemented

def _comparison_key(self):
return self._row, self._col
Expand All @@ -44,6 +78,10 @@ def row(self) -> int:
def col(self) -> int:
return self._col

@property
def dimension(self) -> int:
return self._dimension

def with_dimension(self, dimension: int) -> 'GridQid':
return GridQid(self._row, self._col, dimension=dimension)

Expand Down Expand Up @@ -149,13 +187,10 @@ def __init__(self, row: int, col: int, *, dimension: int) -> None:
dimension: The dimension of the qid's Hilbert space, i.e.
the number of quantum levels.
"""
super().__init__(row, col)
self._dimension = dimension
self.validate_dimension(dimension)

@property
def dimension(self):
return self._dimension
self._row = row
self._col = col
self._dimension = dimension

def _with_row_col(self, row: int, col: int) -> 'GridQid':
return GridQid(row, col, dimension=self.dimension)
Expand Down Expand Up @@ -288,35 +323,11 @@ class GridQubit(_BaseGridQid):
cirq.GridQubit(5, 4)
"""

def __getstate__(self):
# Don't save hash when pickling; see #3777.
state = self.__dict__
hash_key = _compat._method_cache_name(self.__hash__)
if hash_key in state:
state = state.copy()
del state[hash_key]
return state

@_compat.cached_method
def __hash__(self) -> int:
# Explicitly cached for performance (vs delegating to Qid).
return super().__hash__()
_dimension = 2

def __eq__(self, other):
# Explicitly implemented for performance (vs delegating to Qid).
if isinstance(other, GridQubit):
return self._row == other._row and self._col == other._col
return NotImplemented

def __ne__(self, other):
# Explicitly implemented for performance (vs delegating to Qid).
if isinstance(other, GridQubit):
return self._row != other._row or self._col != other._col
return NotImplemented

@property
def dimension(self) -> int:
return 2
def __init__(self, row: int, col: int) -> None:
self._row = row
self._col = col

def _with_row_col(self, row: int, col: int):
return GridQubit(row, col)
Expand Down
4 changes: 1 addition & 3 deletions cirq/devices/grid_qubit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import pytest

import cirq
from cirq import _compat


def test_init():
Expand All @@ -45,8 +44,7 @@ def test_pickled_hash():
q = cirq.GridQubit(3, 4)
q_bad = cirq.GridQubit(3, 4)
_ = hash(q_bad) # compute hash to ensure it is cached.
hash_key = _compat._method_cache_name(cirq.GridQubit.__hash__)
setattr(q_bad, hash_key, getattr(q_bad, hash_key) + 1)
q_bad._hash = q_bad._hash + 1
assert q_bad == q
assert hash(q_bad) != hash(q)
data = pickle.dumps(q_bad)
Expand Down
107 changes: 67 additions & 40 deletions cirq/devices/line_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,48 @@
class _BaseLineQid(ops.Qid):
"""The base class for `LineQid` and `LineQubit`."""

def __init__(self, x: int) -> None:
"""Initializes a line qubit at the given x coordinate."""
self._x = x
_x: int
_dimension: int
_hash: Optional[int] = None

def __getstate__(self):
# Don't save hash when pickling; see #3777.
state = self.__dict__
if "_hash" in state:
state = state.copy()
del state["_hash"]
return state

def __hash__(self) -> int:
if self._hash is None:
self._hash = hash((self._x, self._dimension))
return self._hash

def __eq__(self, other):
# Explicitly implemented for performance (vs delegating to Qid).
if isinstance(other, _BaseLineQid):
return self._x == other._x and self._dimension == other._dimension
return NotImplemented

def __ne__(self, other):
# Explicitly implemented for performance (vs delegating to Qid).
if isinstance(other, _BaseLineQid):
return self._x != other._x or self._dimension != other._dimension
return NotImplemented

def _comparison_key(self):
return self.x
return self._x

@property
def x(self) -> int:
return self._x

@property
def dimension(self) -> int:
return self._dimension

def with_dimension(self, dimension: int) -> 'LineQid':
return LineQid(self.x, dimension)
return LineQid(self._x, dimension)

def is_adjacent(self, other: 'cirq.Qid') -> bool:
"""Determines if two qubits are adjacent line qubits.
Expand All @@ -49,49 +78,45 @@ def is_adjacent(self, other: 'cirq.Qid') -> bool:
Returns: True iff other and self are adjacent.
"""
return isinstance(other, _BaseLineQid) and abs(self.x - other.x) == 1
return isinstance(other, _BaseLineQid) and abs(self._x - other._x) == 1

def neighbors(self, qids: Optional[Iterable[ops.Qid]] = None) -> Set['_BaseLineQid']:
"""Returns qubits that are potential neighbors to this LineQubit
Args:
qids: optional Iterable of qubits to constrain neighbors to.
"""
neighbors = set()
for q in [self - 1, self + 1]:
if qids is None or q in qids:
neighbors.add(q)
return neighbors
return {q for q in [self - 1, self + 1] if qids is None or q in qids}

@abc.abstractmethod
def _with_x(self, x: int) -> Self:
"""Returns a qubit with the same type but a different value of `x`."""

def __add__(self, other: Union[int, Self]) -> Self:
if isinstance(other, _BaseLineQid):
if self.dimension != other.dimension:
if self._dimension != other._dimension:
raise TypeError(
"Can only add LineQids with identical dimension. "
f"Got {self.dimension} and {other.dimension}"
f"Got {self._dimension} and {other._dimension}"
)
return self._with_x(x=self.x + other.x)
return self._with_x(x=self._x + other._x)
if not isinstance(other, int):
raise TypeError(f"Can only add ints and {type(self).__name__}. Instead was {other}")
return self._with_x(self.x + other)
return self._with_x(self._x + other)

def __sub__(self, other: Union[int, Self]) -> Self:
if isinstance(other, _BaseLineQid):
if self.dimension != other.dimension:
if self._dimension != other._dimension:
raise TypeError(
"Can only subtract LineQids with identical dimension. "
f"Got {self.dimension} and {other.dimension}"
f"Got {self._dimension} and {other._dimension}"
)
return self._with_x(x=self.x - other.x)
return self._with_x(x=self._x - other._x)
if not isinstance(other, int):
raise TypeError(
f"Can only subtract ints and {type(self).__name__}. Instead was {other}"
)
return self._with_x(self.x - other)
return self._with_x(self._x - other)

def __radd__(self, other: int) -> Self:
return self + other
Expand All @@ -100,16 +125,16 @@ def __rsub__(self, other: int) -> Self:
return -self + other

def __neg__(self) -> Self:
return self._with_x(-self.x)
return self._with_x(-self._x)

def __complex__(self) -> complex:
return complex(self.x)
return complex(self._x)

def __float__(self) -> float:
return float(self.x)
return float(self._x)

def __int__(self) -> int:
return int(self.x)
return int(self._x)


class LineQid(_BaseLineQid):
Expand Down Expand Up @@ -137,16 +162,12 @@ def __init__(self, x: int, dimension: int) -> None:
dimension: The dimension of the qid's Hilbert space, i.e.
the number of quantum levels.
"""
super().__init__(x)
self._dimension = dimension
self.validate_dimension(dimension)

@property
def dimension(self):
return self._dimension
self._x = x
self._dimension = dimension

def _with_x(self, x: int) -> 'LineQid':
return LineQid(x, dimension=self.dimension)
return LineQid(x, dimension=self._dimension)

@staticmethod
def range(*range_args, dimension: int) -> List['LineQid']:
Expand Down Expand Up @@ -192,15 +213,15 @@ def for_gate(val: Any, start: int = 0, step: int = 1) -> List['LineQid']:
return LineQid.for_qid_shape(qid_shape(val), start=start, step=step)

def __repr__(self) -> str:
return f"cirq.LineQid({self.x}, dimension={self.dimension})"
return f"cirq.LineQid({self._x}, dimension={self._dimension})"

def __str__(self) -> str:
return f"q({self.x}) (d={self.dimension})"
return f"q({self._x}) (d={self._dimension})"

def _circuit_diagram_info_(
self, args: 'cirq.CircuitDiagramInfoArgs'
) -> 'cirq.CircuitDiagramInfo':
return protocols.CircuitDiagramInfo(wire_symbols=(f"{self.x} (d={self.dimension})",))
return protocols.CircuitDiagramInfo(wire_symbols=(f"{self._x} (d={self._dimension})",))

def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['x', 'dimension'])
Expand All @@ -223,9 +244,15 @@ class LineQubit(_BaseLineQid):
"""

@property
def dimension(self) -> int:
return 2
_dimension = 2

def __init__(self, x: int) -> None:
"""Initializes a line qubit at the given x coordinate.
Args:
x: The x coordinate.
"""
self._x = x

def _with_x(self, x: int) -> 'LineQubit':
return LineQubit(x)
Expand All @@ -234,7 +261,7 @@ def _cmp_tuple(self):
cls = LineQid if type(self) is LineQubit else type(self)
# Must be the same as Qid._cmp_tuple but with cls in place of
# type(self).
return (cls.__name__, repr(cls), self._comparison_key(), self.dimension)
return (cls.__name__, repr(cls), self._comparison_key(), self._dimension)

@staticmethod
def range(*range_args) -> List['LineQubit']:
Expand All @@ -249,15 +276,15 @@ def range(*range_args) -> List['LineQubit']:
return [LineQubit(i) for i in range(*range_args)]

def __repr__(self) -> str:
return f"cirq.LineQubit({self.x})"
return f"cirq.LineQubit({self._x})"

def __str__(self) -> str:
return f"q({self.x})"
return f"q({self._x})"

def _circuit_diagram_info_(
self, args: 'cirq.CircuitDiagramInfoArgs'
) -> 'cirq.CircuitDiagramInfo':
return protocols.CircuitDiagramInfo(wire_symbols=(f"{self.x}",))
return protocols.CircuitDiagramInfo(wire_symbols=(f"{self._x}",))

def _json_dict_(self) -> Dict[str, Any]:
return protocols.obj_to_dict_helper(self, ['x'])
Loading

0 comments on commit 4efa93d

Please sign in to comment.