Skip to content

Commit 465bfa9

Browse files
95-martin-orionrht
authored andcommitted
Use mapping in cirq/work (quantumlib#5609)
Fixes quantumlib#5554.
1 parent 13a7d77 commit 465bfa9

File tree

4 files changed

+19
-8
lines changed

4 files changed

+19
-8
lines changed

cirq-core/cirq/work/observable_measurement.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,7 @@ def measure_grouped_settings(
531531
for max_setting, param_resolver in itertools.product(
532532
grouped_settings.keys(), study.to_resolvers(circuit_sweep)
533533
):
534-
circuit_params = dict(param_resolver.param_dict)
534+
circuit_params = param_resolver.param_dict
535535
meas_spec = _MeasurementSpec(max_setting=max_setting, circuit_params=circuit_params)
536536
accumulator = BitstringAccumulator(
537537
meas_spec=meas_spec,

cirq-core/cirq/work/observable_measurement_data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import dataclasses
1616
import datetime
17-
from typing import Any, Dict, Iterable, List, Tuple, TYPE_CHECKING, Union
17+
from typing import Any, Dict, Iterable, List, Mapping, Tuple, TYPE_CHECKING, Union
1818

1919
import numpy as np
2020
import sympy
@@ -107,7 +107,7 @@ class ObservableMeasuredResult:
107107
mean: float
108108
variance: float
109109
repetitions: int
110-
circuit_params: Dict[Union[str, sympy.Expr], Union[value.Scalar, sympy.Expr]]
110+
circuit_params: Mapping[Union[str, sympy.Expr], Union[value.Scalar, sympy.Expr]]
111111

112112
def __repr__(self):
113113
# I wish we could use the default dataclass __repr__ but

cirq-core/cirq/work/observable_settings.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,17 @@
1414

1515
import dataclasses
1616
import numbers
17-
from typing import Union, Iterable, Dict, Optional, TYPE_CHECKING, ItemsView, Tuple, FrozenSet
17+
from typing import (
18+
AbstractSet,
19+
Mapping,
20+
Union,
21+
Iterable,
22+
Dict,
23+
Optional,
24+
TYPE_CHECKING,
25+
Tuple,
26+
FrozenSet,
27+
)
1828

1929
import sympy
2030

@@ -143,7 +153,8 @@ def _fix_precision(val: Union[value.Scalar, sympy.Expr], precision) -> Union[int
143153

144154

145155
def _hashable_param(
146-
param_tuples: ItemsView[Union[str, sympy.Expr], Union[value.Scalar, sympy.Expr]], precision=1e7
156+
param_tuples: AbstractSet[Tuple[Union[str, sympy.Expr], Union[value.Scalar, sympy.Expr]]],
157+
precision=1e7,
147158
) -> FrozenSet[Tuple[str, Union[int, Tuple[int, int]]]]:
148159
"""Hash circuit parameters using fixed precision.
149160
@@ -166,7 +177,7 @@ class _MeasurementSpec:
166177
"""
167178

168179
max_setting: InitObsSetting
169-
circuit_params: Dict[Union[str, sympy.Expr], Union[value.Scalar, sympy.Expr]]
180+
circuit_params: Mapping[Union[str, sympy.Expr], Union[value.Scalar, sympy.Expr]]
170181

171182
def __hash__(self):
172183
return hash((self.max_setting, _hashable_param(self.circuit_params.items())))

cirq-core/cirq/work/sampler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,8 +352,8 @@ def sample_expectation_values(
352352

353353
# Flatten Circuit Sweep into one big list of Params.
354354
# Keep track of their indices so we can map back.
355-
flat_params: List['cirq.ParamDictType'] = [
356-
dict(pr.param_dict) for pr in study.to_resolvers(params)
355+
flat_params: List['cirq.ParamMappingType'] = [
356+
pr.param_dict for pr in study.to_resolvers(params)
357357
]
358358
circuit_param_to_sweep_i: Dict[FrozenSet[Tuple[str, Union[int, Tuple[int, int]]]], int] = {
359359
_hashable_param(param.items()): i for i, param in enumerate(flat_params)

0 commit comments

Comments
 (0)