Skip to content

Commit

Permalink
Improve parameter-related types and documentation (#3330)
Browse files Browse the repository at this point in the history
This is part of the suggestion in #3256 (review).
  • Loading branch information
kevinsung authored Sep 16, 2020
1 parent 946665d commit 11933d1
Show file tree
Hide file tree
Showing 8 changed files with 33 additions and 16 deletions.
1 change: 1 addition & 0 deletions cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,7 @@
PeriodicValue,
RANDOM_STATE_OR_SEED_LIKE,
Timestamp,
TParamKey,
TParamVal,
validate_probability,
value_equality,
Expand Down
1 change: 1 addition & 0 deletions cirq/protocols/json_serialization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ def test_fail_to_resolve():
'RANDOM_STATE_OR_SEED_LIKE',
'STATE_VECTOR_LIKE',
'Sweepable',
'TParamKey',
'TParamVal',
'ParamDictType',

Expand Down
22 changes: 12 additions & 10 deletions cirq/study/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import cirq


ParamDictType = Dict[Union[str, sympy.Symbol], Union[float, str, sympy.Basic]]
ParamDictType = Dict['cirq.TParamKey', 'cirq.TParamVal']
document(
ParamDictType, # type: ignore
"""Dictionary from symbols to values.""")
Expand All @@ -36,10 +36,11 @@


class ParamResolver:
"""Resolves sympy.Symbols to actual values.
"""Resolves parameters to actual values.
A Symbol is a wrapped parameter name (str). A ParamResolver is an object
that can be used to assign values for these keys.
A parameter is a variable whose value has not been determined.
A ParamResolver is an object that can be used to assign values for these
variables.
ParamResolvers are hashable.
Expand All @@ -63,15 +64,17 @@ def __init__(self,
{} if param_dict is None else param_dict)

def value_of(self,
value: Union[sympy.Basic, float, str]) -> 'cirq.TParamVal':
"""Attempt to resolve a Symbol, string, or float to its assigned value.
value: Union['cirq.TParamKey', float]) -> 'cirq.TParamVal':
"""Attempt to resolve a parameter to its assigned value.
Floats are returned without modification. Strings are resolved via
the parameter dictionary with exact match only. Otherwise, strings
are considered to be sympy.Symbols with the name as the input string.
sympy.Symbols are first checked for exact match in the parameter
dictionary. Otherwise, the symbol is resolved using sympy substitution.
A sympy.Symbol is first checked for exact match in the parameter
dictionary. Otherwise, it is treated as a sympy.Basic.
A sympy.Basic is resolved using sympy substitution.
Note that passing a formula to this resolver can be slow due to the
underlying sympy library. For circuits relying on quick performance,
Expand All @@ -81,8 +84,7 @@ def value_of(self,
If unable to resolve a name, returns a sympy.Symbol with that name.
Args:
value: The sympy.Symbol or name or float to try to resolve into just
a float.
value: The parameter to try to resolve.
Returns:
The value of the parameter as resolved by this resolver.
Expand Down
7 changes: 5 additions & 2 deletions cirq/study/sweepable.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@
from cirq.study.resolver import ParamResolver, ParamResolverOrSimilarType
from cirq.study.sweeps import ListSweep, Points, Sweep, UnitSweep, Zip

SweepLike = Union[ParamResolverOrSimilarType, Sweep]
document(
SweepLike, # type: ignore
"""An object similar to an iterable of parameter resolvers.""")

Sweepable = Union[Dict[str, float], ParamResolver, Sweep, Iterable[
Union[Dict[str, float], ParamResolver, Sweep]], None]
Sweepable = Union[SweepLike, Iterable[SweepLike]]
document(
Sweepable, # type: ignore
"""An object or collection of objects representing a parameter sweep.""")
Expand Down
6 changes: 4 additions & 2 deletions cirq/study/sweeps.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import (Any, cast, Dict, Iterable, Iterator, List, overload,
Sequence, Tuple, Union)
Sequence, TYPE_CHECKING, Tuple, Union)

import abc
import collections
Expand All @@ -22,8 +22,10 @@
from cirq._doc import document
from cirq.study import resolver

if TYPE_CHECKING:
import cirq

Params = Iterable[Tuple[str, float]]
Params = Iterable[Tuple['cirq.TParamKey', 'cirq.TParamVal']]


def _check_duplicate_keys(sweeps):
Expand Down
4 changes: 3 additions & 1 deletion cirq/value/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@
Timestamp,)

from cirq.value.type_alias import (
TParamVal,)
TParamKey,
TParamVal,
)

from cirq.value.value_equality_attr import (
value_equality,)
7 changes: 6 additions & 1 deletion cirq/value/type_alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@
"""Supply aliases for commonly used types.
"""

TParamKey = Union[str, sympy.Basic]
document(
TParamKey, # type: ignore
"""A parameter that a parameter resolver may map to a value.""")

TParamVal = Union[float, sympy.Basic]
document(
TParamVal, # type: ignore
"""A value that a parameter resolver may return for a symbol.""")
"""A value that a parameter resolver may return for a parameter.""")
1 change: 1 addition & 0 deletions rtd_docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,7 @@ important roles in the internal machinery of the library.
cirq.LinearCombinationOfGates
cirq.LinearCombinationOfOperations
cirq.SingleQubitPauliStringGateOperation
cirq.TParamKey
cirq.TParamVal


Expand Down

0 comments on commit 11933d1

Please sign in to comment.