Skip to content

Commit

Permalink
Use frozensets for key protocols (quantumlib#5560)
Browse files Browse the repository at this point in the history
Fixes quantumlib#5557, using frozenset in key protocol dunder methods, thus allowing the top-level protocol method to avoid defensive copies. Provide a temporary warning until v0.16 for third-party dunders to do the same. This speeds up tight loops about 15% in my testing. @95-martin-orion 

Note this is not breaking, as FrozenSet is a subclass of AbstractSet. (AbstractSet is a readonly interface of a set, and is superclass of both FrozenSet and plain Set). Changed the return types to FrozenSet to be more explicit and make chaining easier.
  • Loading branch information
daxfohl authored and rht committed May 1, 2023
1 parent 6603b0e commit 4d30756
Show file tree
Hide file tree
Showing 10 changed files with 123 additions and 79 deletions.
20 changes: 12 additions & 8 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,26 +918,30 @@ def qid_shape(
qids = ops.QubitOrder.as_qubit_order(qubit_order).order_for(self.all_qubits())
return protocols.qid_shape(qids)

def all_measurement_key_objs(self) -> AbstractSet['cirq.MeasurementKey']:
return {key for op in self.all_operations() for key in protocols.measurement_key_objs(op)}
def all_measurement_key_objs(self) -> FrozenSet['cirq.MeasurementKey']:
return frozenset(
key for op in self.all_operations() for key in protocols.measurement_key_objs(op)
)

def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']:
def _measurement_key_objs_(self) -> FrozenSet['cirq.MeasurementKey']:
"""Returns the set of all measurement keys in this circuit.
Returns: AbstractSet of `cirq.MeasurementKey` objects that are
Returns: FrozenSet of `cirq.MeasurementKey` objects that are
in this circuit.
"""
return self.all_measurement_key_objs()

def all_measurement_key_names(self) -> AbstractSet[str]:
def all_measurement_key_names(self) -> FrozenSet[str]:
"""Returns the set of all measurement key names in this circuit.
Returns: AbstractSet of strings that are the measurement key
Returns: FrozenSet of strings that are the measurement key
names in this circuit.
"""
return {key for op in self.all_operations() for key in protocols.measurement_key_names(op)}
return frozenset(
key for op in self.all_operations() for key in protocols.measurement_key_names(op)
)

def _measurement_key_names_(self) -> AbstractSet[str]:
def _measurement_key_names_(self) -> FrozenSet[str]:
return self.all_measurement_key_names()

def _with_measurement_key_mapping_(self, key_map: Dict[str, str]):
Expand Down
29 changes: 15 additions & 14 deletions cirq-core/cirq/circuits/circuit_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
"""
import math
from typing import (
AbstractSet,
Callable,
Mapping,
Sequence,
Expand Down Expand Up @@ -309,30 +308,32 @@ def _ensure_deterministic_loop_count(self):
raise ValueError('Cannot unroll circuit due to nondeterministic repetitions')

@cached_property
def _measurement_key_objs(self) -> AbstractSet['cirq.MeasurementKey']:
def _measurement_key_objs(self) -> FrozenSet['cirq.MeasurementKey']:
circuit_keys = protocols.measurement_key_objs(self.circuit)
if circuit_keys and self.use_repetition_ids:
self._ensure_deterministic_loop_count()
if self.repetition_ids is not None:
circuit_keys = {
circuit_keys = frozenset(
key.with_key_path_prefix(repetition_id)
for repetition_id in self.repetition_ids
for key in circuit_keys
}
circuit_keys = {key.with_key_path_prefix(*self.parent_path) for key in circuit_keys}
return {
)
circuit_keys = frozenset(
key.with_key_path_prefix(*self.parent_path) for key in circuit_keys
)
return frozenset(
protocols.with_measurement_key_mapping(key, dict(self.measurement_key_map))
for key in circuit_keys
}
)

def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']:
def _measurement_key_objs_(self) -> FrozenSet['cirq.MeasurementKey']:
return self._measurement_key_objs

def _measurement_key_names_(self) -> AbstractSet[str]:
return {str(key) for key in self._measurement_key_objs_()}
def _measurement_key_names_(self) -> FrozenSet[str]:
return frozenset(str(key) for key in self._measurement_key_objs_())

@cached_property
def _control_keys(self) -> AbstractSet['cirq.MeasurementKey']:
def _control_keys(self) -> FrozenSet['cirq.MeasurementKey']:
keys = (
frozenset()
if not protocols.control_keys(self.circuit)
Expand All @@ -342,13 +343,13 @@ def _control_keys(self) -> AbstractSet['cirq.MeasurementKey']:
keys |= frozenset(self.repeat_until.keys) - self._measurement_key_objs_()
return keys

def _control_keys_(self) -> AbstractSet['cirq.MeasurementKey']:
def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']:
return self._control_keys

def _is_parameterized_(self) -> bool:
return any(self._parameter_names_generator())

def _parameter_names_(self) -> AbstractSet[str]:
def _parameter_names_(self) -> FrozenSet[str]:
return frozenset(self._parameter_names_generator())

def _parameter_names_generator(self) -> Iterator[str]:
Expand Down Expand Up @@ -463,7 +464,7 @@ def __str__(self):
)
args = []

def dict_str(d: Dict) -> str:
def dict_str(d: Mapping) -> str:
pairs = [f'{k}: {v}' for k, v in sorted(d.items())]
return '{' + ', '.join(pairs) + '}'

Expand Down
31 changes: 10 additions & 21 deletions cirq-core/cirq/circuits/frozen_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""An immutable version of the Circuit data structure."""
from typing import (
TYPE_CHECKING,
AbstractSet,
FrozenSet,
Iterable,
Iterator,
Optional,
Sequence,
Tuple,
Union,
)
from typing import TYPE_CHECKING, FrozenSet, Iterable, Iterator, Optional, Sequence, Tuple, Union

import numpy as np

from cirq import ops, protocols
from cirq.circuits import AbstractCircuit, Alignment, Circuit
from cirq.circuits.insert_strategy import InsertStrategy
from cirq.type_workarounds import NotImplementedType

import numpy as np

from cirq import ops, protocols, _compat


if TYPE_CHECKING:
import cirq

Expand Down Expand Up @@ -70,7 +59,7 @@ def __init__(
self._all_qubits: Optional[FrozenSet['cirq.Qid']] = None
self._all_operations: Optional[Tuple[ops.Operation, ...]] = None
self._has_measurements: Optional[bool] = None
self._all_measurement_key_objs: Optional[AbstractSet['cirq.MeasurementKey']] = None
self._all_measurement_key_objs: Optional[FrozenSet['cirq.MeasurementKey']] = None
self._are_all_measurements_terminal: Optional[bool] = None
self._control_keys: Optional[FrozenSet['cirq.MeasurementKey']] = None

Expand Down Expand Up @@ -118,12 +107,12 @@ def has_measurements(self) -> bool:
self._has_measurements = super().has_measurements()
return self._has_measurements

def all_measurement_key_objs(self) -> AbstractSet['cirq.MeasurementKey']:
def all_measurement_key_objs(self) -> FrozenSet['cirq.MeasurementKey']:
if self._all_measurement_key_objs is None:
self._all_measurement_key_objs = super().all_measurement_key_objs()
return self._all_measurement_key_objs

def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']:
def _measurement_key_objs_(self) -> FrozenSet['cirq.MeasurementKey']:
return self.all_measurement_key_objs()

def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']:
Expand All @@ -138,10 +127,10 @@ def are_all_measurements_terminal(self) -> bool:

# End of memoized methods.

def all_measurement_key_names(self) -> AbstractSet[str]:
return {str(key) for key in self.all_measurement_key_objs()}
def all_measurement_key_names(self) -> FrozenSet[str]:
return frozenset(str(key) for key in self.all_measurement_key_objs())

def _measurement_key_names_(self) -> AbstractSet[str]:
def _measurement_key_names_(self) -> FrozenSet[str]:
return self.all_measurement_key_names()

def __add__(self, other) -> 'cirq.FrozenCircuit':
Expand Down
5 changes: 2 additions & 3 deletions cirq-core/cirq/circuits/moment.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import itertools
from typing import (
AbstractSet,
Any,
Callable,
Dict,
Expand Down Expand Up @@ -238,8 +237,8 @@ def _with_measurement_key_mapping_(self, key_map: Dict[str, str]):
for op in self.operations
)

def _measurement_key_names_(self) -> AbstractSet[str]:
return {str(key) for key in self._measurement_key_objs_()}
def _measurement_key_names_(self) -> FrozenSet[str]:
return frozenset(str(key) for key in self._measurement_key_objs_())

def _measurement_key_objs_(self) -> FrozenSet['cirq.MeasurementKey']:
if self._measurement_key_objs is None:
Expand Down
6 changes: 4 additions & 2 deletions cirq-core/cirq/ops/gate_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def _measurement_key_name_(self) -> Optional[str]:
return getter()
return NotImplemented

def _measurement_key_names_(self) -> Optional[AbstractSet[str]]:
def _measurement_key_names_(self) -> Union[FrozenSet[str], NotImplementedType, None]:
getter = getattr(self.gate, '_measurement_key_names_', None)
if getter is not None:
return getter()
Expand All @@ -247,7 +247,9 @@ def _measurement_key_obj_(self) -> Optional['cirq.MeasurementKey']:
return getter()
return NotImplemented

def _measurement_key_objs_(self) -> Optional[AbstractSet['cirq.MeasurementKey']]:
def _measurement_key_objs_(
self,
) -> Union[FrozenSet['cirq.MeasurementKey'], NotImplementedType, None]:
getter = getattr(self.gate, '_measurement_key_objs_', None)
if getter is not None:
return getter()
Expand Down
6 changes: 3 additions & 3 deletions cirq-core/cirq/ops/raw_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,10 +820,10 @@ def _has_kraus_(self) -> bool:
def _kraus_(self) -> Union[Tuple[np.ndarray], NotImplementedType]:
return protocols.kraus(self.sub_operation, NotImplemented)

def _measurement_key_names_(self) -> AbstractSet[str]:
def _measurement_key_names_(self) -> FrozenSet[str]:
return protocols.measurement_key_names(self.sub_operation)

def _measurement_key_objs_(self) -> AbstractSet['cirq.MeasurementKey']:
def _measurement_key_objs_(self) -> FrozenSet['cirq.MeasurementKey']:
return protocols.measurement_key_objs(self.sub_operation)

def _is_measurement_(self) -> bool:
Expand Down Expand Up @@ -905,7 +905,7 @@ def with_classical_controls(
return self
return self.sub_operation.with_classical_controls(*conditions)

def _control_keys_(self) -> AbstractSet['cirq.MeasurementKey']:
def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']:
return protocols.control_keys(self.sub_operation)


Expand Down
20 changes: 14 additions & 6 deletions cirq-core/cirq/protocols/control_key_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
# limitations under the License.
"""Protocol for object that have control keys."""

from typing import AbstractSet, Any, Iterable, TYPE_CHECKING
from typing import Any, FrozenSet, TYPE_CHECKING, Union

from typing_extensions import Protocol

from cirq import _compat
from cirq._doc import doc_private
from cirq.protocols import measurement_key_protocol
from cirq.type_workarounds import NotImplementedType

if TYPE_CHECKING:
import cirq
Expand All @@ -34,7 +36,7 @@ class SupportsControlKey(Protocol):
"""

@doc_private
def _control_keys_(self) -> Iterable['cirq.MeasurementKey']:
def _control_keys_(self) -> Union[FrozenSet['cirq.MeasurementKey'], NotImplementedType, None]:
"""Return the keys for controls referenced by the receiving object.
Returns:
Expand All @@ -43,7 +45,7 @@ def _control_keys_(self) -> Iterable['cirq.MeasurementKey']:
"""


def control_keys(val: Any) -> AbstractSet['cirq.MeasurementKey']:
def control_keys(val: Any) -> FrozenSet['cirq.MeasurementKey']:
"""Gets the keys that the value is classically controlled by.
Args:
Expand All @@ -56,12 +58,18 @@ def control_keys(val: Any) -> AbstractSet['cirq.MeasurementKey']:
getter = getattr(val, '_control_keys_', None)
result = NotImplemented if getter is None else getter()
if result is not NotImplemented and result is not None:
return set(result)
if not isinstance(result, FrozenSet):
_compat._warn_or_error(
f'The _control_keys_ implementation of {type(val)} must return a'
f' frozenset instead of {type(result)} by v0.16.'
)
return frozenset(result)
return result

return set()
return frozenset()


def measurement_keys_touched(val: Any) -> AbstractSet['cirq.MeasurementKey']:
def measurement_keys_touched(val: Any) -> FrozenSet['cirq.MeasurementKey']:
"""Returns all the measurement keys used by the value.
This would be the case if the value is or contains a measurement gate, or
Expand Down
11 changes: 10 additions & 1 deletion cirq-core/cirq/protocols/control_key_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
def test_control_key():
class Named:
def _control_keys_(self):
return [cirq.MeasurementKey('key')]
return frozenset([cirq.MeasurementKey('key')])

class NoImpl:
def _control_keys_(self):
Expand All @@ -27,3 +27,12 @@ def _control_keys_(self):
assert cirq.control_keys(Named()) == {cirq.MeasurementKey('key')}
assert not cirq.control_keys(NoImpl())
assert not cirq.control_keys(5)


def test_control_key_enumerable_deprecated():
class Deprecated:
def _control_keys_(self):
return [cirq.MeasurementKey('key')]

with cirq.testing.assert_deprecated('frozenset', deadline='v0.16'):
assert cirq.control_keys(Deprecated()) == {cirq.MeasurementKey('key')}
Loading

0 comments on commit 4d30756

Please sign in to comment.