Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lock down CircuitOperation and ParamResolver #5548

Merged
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,7 @@
Linspace,
ListSweep,
ParamDictType,
ParamMappingType,
ParamResolver,
ParamResolverOrSimilarType,
Points,
Expand Down
387 changes: 226 additions & 161 deletions cirq-core/cirq/circuits/circuit_operation.py

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions cirq-core/cirq/circuits/circuit_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,8 +994,10 @@ def test_keys_under_parent_path():
assert cirq.measurement_key_names(op1) == {'A'}
op2 = op1.with_key_path(('B',))
assert cirq.measurement_key_names(op2) == {'B:A'}
op3 = op2.repeat(2)
assert cirq.measurement_key_names(op3) == {'B:0:A', 'B:1:A'}
op3 = cirq.with_key_path_prefix(op2, ('C',))
assert cirq.measurement_key_names(op3) == {'C:B:A'}
op4 = op3.repeat(2)
assert cirq.measurement_key_names(op4) == {'C:B:0:A', 'C:B:1:A'}


def test_mapped_circuit_preserves_moments():
Expand Down
1 change: 1 addition & 0 deletions cirq-core/cirq/protocols/json_test_data/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@
'TParamValComplex',
'TRANSFORMER',
'ParamDictType',
'ParamMappingType',
# utility:
'CliffordSimulator',
'NoiseModelFromNoiseProperties',
Expand Down
7 changes: 6 additions & 1 deletion cirq-core/cirq/study/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@
flatten_with_sweep,
)

from cirq.study.resolver import ParamDictType, ParamResolver, ParamResolverOrSimilarType
from cirq.study.resolver import (
ParamDictType,
ParamMappingType,
ParamResolver,
ParamResolverOrSimilarType,
)

from cirq.study.sweepable import Sweepable, to_resolvers, to_sweep, to_sweeps

Expand Down
6 changes: 3 additions & 3 deletions cirq-core/cirq/study/flatten_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def value_of(
return out
# Create a new symbol
symbol = self._next_symbol(value)
self.param_dict[value] = symbol
self._param_dict[value] = symbol
self._taken_symbols.add(symbol)
return symbol

Expand All @@ -292,9 +292,9 @@ def __bool__(self) -> bool:

def __repr__(self) -> str:
if self.get_param_name == self.default_get_param_name:
return f'_ParamFlattener({self.param_dict!r})'
return f'_ParamFlattener({self._param_dict!r})'
else:
return f'_ParamFlattener({self.param_dict!r}, get_param_name={self.get_param_name!r})'
return f'_ParamFlattener({self._param_dict!r}, get_param_name={self.get_param_name!r})'

def flatten(self, val: Any) -> Any:
"""Returns a copy of `val` with any symbols or expressions replaced with
Expand Down
12 changes: 9 additions & 3 deletions cirq-core/cirq/study/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Resolves ParameterValues to assigned values."""
import numbers
from typing import Any, Dict, Iterator, Optional, TYPE_CHECKING, Union, cast
from typing import Any, Dict, Iterator, Mapping, Optional, TYPE_CHECKING, Union, cast

import numpy as np
import sympy
Expand All @@ -27,9 +27,11 @@


ParamDictType = Dict['cirq.TParamKey', 'cirq.TParamValComplex']
ParamMappingType = Mapping['cirq.TParamKey', 'cirq.TParamValComplex']
document(ParamDictType, """Dictionary from symbols to values.""") # type: ignore
document(ParamMappingType, """Immutable map from symbols to values.""") # type: ignore

ParamResolverOrSimilarType = Union['cirq.ParamResolver', ParamDictType, None]
ParamResolverOrSimilarType = Union['cirq.ParamResolver', ParamMappingType, None]
document(
ParamResolverOrSimilarType, # type: ignore
"""Something that can be used to turn parameters into values.""",
Expand Down Expand Up @@ -70,12 +72,16 @@ def __init__(self, param_dict: 'cirq.ParamResolverOrSimilarType' = None) -> None
return # Already initialized. Got wrapped as part of the __new__.

self._param_hash: Optional[int] = None
self.param_dict = cast(ParamDictType, {} if param_dict is None else param_dict)
self._param_dict = cast(ParamDictType, {} if param_dict is None else param_dict)
for key in self.param_dict:
if isinstance(key, sympy.Expr) and not isinstance(key, sympy.Symbol):
raise TypeError(f'ParamResolver keys cannot be (non-symbol) formulas ({key})')
self._deep_eval_map: ParamDictType = {}

@property
def param_dict(self) -> ParamMappingType:
return self._param_dict

def value_of(
self, value: Union['cirq.TParamKey', 'cirq.TParamValComplex'], recursive: bool = True
) -> 'cirq.TParamValComplex':
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/work/observable_measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ def measure_grouped_settings(
for max_setting, param_resolver in itertools.product(
grouped_settings.keys(), study.to_resolvers(circuit_sweep)
):
circuit_params = param_resolver.param_dict
circuit_params = dict(param_resolver.param_dict)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's file an issue about fixing the types so we can avoid the copies here and in sample_expectation_values below. I think these should be fixable without worrying too much about compatibility, since AFAICT the dicts are not exposed anywhere in a way that is intended to be mutable, but we can deal with that in a future PR.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Opened #5554.

meas_spec = _MeasurementSpec(max_setting=max_setting, circuit_params=circuit_params)
accumulator = BitstringAccumulator(
meas_spec=meas_spec,
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/work/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def sample_expectation_values(
# Flatten Circuit Sweep into one big list of Params.
# Keep track of their indices so we can map back.
flat_params: List['cirq.ParamDictType'] = [
pr.param_dict for pr in study.to_resolvers(params)
dict(pr.param_dict) for pr in study.to_resolvers(params)
]
circuit_param_to_sweep_i: Dict[FrozenSet[Tuple[str, Union[int, Tuple[int, int]]]], int] = {
_hashable_param(param.items()): i for i, param in enumerate(flat_params)
Expand Down
2 changes: 1 addition & 1 deletion cirq-rigetti/cirq_rigetti/circuit_sweep_executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _get_param_dict(resolver: cirq.ParamResolverOrSimilarType) -> Dict[Union[str
"""
param_dict: Dict[Union[str, sympy.Expr], Any] = {}
if isinstance(resolver, cirq.ParamResolver):
param_dict = resolver.param_dict
param_dict = dict(resolver.param_dict)
elif isinstance(resolver, dict):
param_dict = resolver
return param_dict
Expand Down
2 changes: 1 addition & 1 deletion cirq-rigetti/cirq_rigetti/circuit_sweep_executors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_with_quilc_parametric_compilation(

param_resolvers: List[Union[cirq.ParamResolver, cirq.ParamDictType]]
if pass_dict:
param_resolvers = [params.param_dict for params in sweepable]
param_resolvers = [dict(params.param_dict) for params in sweepable]
else:
param_resolvers = [r for r in cirq.to_resolvers(sweepable)]
expected_results = [
Expand Down