Skip to content

Commit

Permalink
Speed up parameter resolution for cirq.Duration (#6270)
Browse files Browse the repository at this point in the history
  • Loading branch information
tanujkhattar authored Aug 30, 2023
1 parent c7048f5 commit b28bfce
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 37 deletions.
105 changes: 68 additions & 37 deletions cirq-core/cirq/value/duration.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
# limitations under the License.
"""A typed time delta that supports picosecond accuracy."""

from typing import AbstractSet, Any, Dict, Optional, Tuple, TYPE_CHECKING, Union
from typing import AbstractSet, Any, Dict, Optional, Tuple, TYPE_CHECKING, Union, List
import datetime

import sympy
import numpy as np

from cirq import protocols
from cirq._compat import proper_repr
from cirq._compat import proper_repr, cached_method
from cirq._doc import document

if TYPE_CHECKING:
Expand Down Expand Up @@ -79,48 +79,53 @@ def __init__(
>>> print(cirq.Duration(micros=1.5 * sympy.Symbol('t')))
(1500.0*t) ns
"""
self._time_vals: List[_NUMERIC_INPUT_TYPE] = [0, 0, 0, 0]
self._multipliers = [1, 1000, 1000_000, 1000_000_000]
if value is not None and value != 0:
if isinstance(value, datetime.timedelta):
# timedelta has microsecond resolution.
micros += int(value / datetime.timedelta(microseconds=1))
self._time_vals[2] = int(value / datetime.timedelta(microseconds=1))
elif isinstance(value, Duration):
picos += value._picos
self._time_vals = value._time_vals
else:
raise TypeError(f'Not a `cirq.DURATION_LIKE`: {repr(value)}.')

val = picos + nanos * 1000 + micros * 1000_000 + millis * 1000_000_000
self._picos: _NUMERIC_OUTPUT_TYPE = float(val) if isinstance(val, np.number) else val
input_vals = [picos, nanos, micros, millis]
self._time_vals = _add_time_vals(self._time_vals, input_vals)

def _is_parameterized_(self) -> bool:
return protocols.is_parameterized(self._picos)
return protocols.is_parameterized(self._time_vals)

def _parameter_names_(self) -> AbstractSet[str]:
return protocols.parameter_names(self._picos)
return protocols.parameter_names(self._time_vals)

def _resolve_parameters_(self, resolver: 'cirq.ParamResolver', recursive: bool) -> 'Duration':
return Duration(picos=protocols.resolve_parameters(self._picos, resolver, recursive))
return _duration_from_time_vals(
protocols.resolve_parameters(self._time_vals, resolver, recursive)
)

@cached_method
def total_picos(self) -> _NUMERIC_OUTPUT_TYPE:
"""Returns the number of picoseconds that the duration spans."""
return self._picos
val = sum(a * b for a, b in zip(self._time_vals, self._multipliers))
return float(val) if isinstance(val, np.number) else val

def total_nanos(self) -> _NUMERIC_OUTPUT_TYPE:
"""Returns the number of nanoseconds that the duration spans."""
return self._picos / 1000
return self.total_picos() / 1000

def total_micros(self) -> _NUMERIC_OUTPUT_TYPE:
"""Returns the number of microseconds that the duration spans."""
return self._picos / 1000_000
return self.total_picos() / 1000_000

def total_millis(self) -> _NUMERIC_OUTPUT_TYPE:
"""Returns the number of milliseconds that the duration spans."""
return self._picos / 1000_000_000
return self.total_picos() / 1000_000_000

def __add__(self, other) -> 'Duration':
other = _attempt_duration_like_to_duration(other)
if other is None:
return NotImplemented
return Duration(picos=self._picos + other._picos)
return _duration_from_time_vals(_add_time_vals(self._time_vals, other._time_vals))

def __radd__(self, other) -> 'Duration':
return self.__add__(other)
Expand All @@ -129,86 +134,94 @@ def __sub__(self, other) -> 'Duration':
other = _attempt_duration_like_to_duration(other)
if other is None:
return NotImplemented
return Duration(picos=self._picos - other._picos)
return _duration_from_time_vals(
_add_time_vals(self._time_vals, [-x for x in other._time_vals])
)

def __rsub__(self, other) -> 'Duration':
other = _attempt_duration_like_to_duration(other)
if other is None:
return NotImplemented
return Duration(picos=other._picos - self._picos)
return _duration_from_time_vals(
_add_time_vals(other._time_vals, [-x for x in self._time_vals])
)

def __mul__(self, other) -> 'Duration':
if not isinstance(other, (int, float, sympy.Expr)):
return NotImplemented
return Duration(picos=self._picos * other)
if other == 0:
return _duration_from_time_vals([0] * 4)
return _duration_from_time_vals([x * other for x in self._time_vals])

def __rmul__(self, other) -> 'Duration':
return self.__mul__(other)

def __truediv__(self, other) -> Union['Duration', float]:
if isinstance(other, (int, float, sympy.Expr)):
return Duration(picos=self._picos / other)
new_time_vals = [x / other for x in self._time_vals]
return _duration_from_time_vals(new_time_vals)

other_duration = _attempt_duration_like_to_duration(other)
if other_duration is not None:
return self._picos / other_duration._picos
return self.total_picos() / other_duration.total_picos()

return NotImplemented

def __eq__(self, other):
other = _attempt_duration_like_to_duration(other)
if other is None:
return NotImplemented
return self._picos == other._picos
return self.total_picos() == other.total_picos()

def __ne__(self, other):
other = _attempt_duration_like_to_duration(other)
if other is None:
return NotImplemented
return self._picos != other._picos
return self.total_picos() != other.total_picos()

def __gt__(self, other):
other = _attempt_duration_like_to_duration(other)
if other is None:
return NotImplemented
return self._picos > other._picos
return self.total_picos() > other.total_picos()

def __lt__(self, other):
other = _attempt_duration_like_to_duration(other)
if other is None:
return NotImplemented
return self._picos < other._picos
return self.total_picos() < other.total_picos()

def __ge__(self, other):
other = _attempt_duration_like_to_duration(other)
if other is None:
return NotImplemented
return self._picos >= other._picos
return self.total_picos() >= other.total_picos()

def __le__(self, other):
other = _attempt_duration_like_to_duration(other)
if other is None:
return NotImplemented
return self._picos <= other._picos
return self.total_picos() <= other.total_picos()

def __bool__(self):
return bool(self._picos)
return bool(self.total_picos())

def __hash__(self):
if isinstance(self._picos, (int, float)) and self._picos % 1000000 == 0:
return hash(datetime.timedelta(microseconds=self._picos / 1000000))
return hash((Duration, self._picos))
if isinstance(self.total_picos(), (int, float)) and self.total_picos() % 1000000 == 0:
return hash(datetime.timedelta(microseconds=self.total_picos() / 1000000))
return hash((Duration, self.total_picos()))

def _decompose_into_amount_unit_suffix(self) -> Tuple[int, str, str]:
picos = self.total_picos()
if (
isinstance(self._picos, sympy.Mul)
and len(self._picos.args) == 2
and isinstance(self._picos.args[0], (sympy.Integer, sympy.Float))
isinstance(picos, sympy.Mul)
and len(picos.args) == 2
and isinstance(picos.args[0], (sympy.Integer, sympy.Float))
):
scale = self._picos.args[0]
rest = self._picos.args[1]
scale = picos.args[0]
rest = picos.args[1]
else:
scale = self._picos
scale = picos
rest = 1

if scale % 1000_000_000 == 0:
Expand All @@ -234,7 +247,7 @@ def _decompose_into_amount_unit_suffix(self) -> Tuple[int, str, str]:
return amount * rest, unit, suffix

def __str__(self) -> str:
if self._picos == 0:
if self.total_picos() == 0:
return 'Duration(0)'
amount, _, suffix = self._decompose_into_amount_unit_suffix()
if not isinstance(amount, (int, float, sympy.Symbol)):
Expand All @@ -257,3 +270,21 @@ def _attempt_duration_like_to_duration(value: Any) -> Optional[Duration]:
if isinstance(value, (int, float)) and value == 0:
return Duration()
return None


def _add_time_vals(
val1: List[_NUMERIC_INPUT_TYPE], val2: List[_NUMERIC_INPUT_TYPE]
) -> List[_NUMERIC_INPUT_TYPE]:
ret: List[_NUMERIC_INPUT_TYPE] = []
for i in range(4):
if val1[i] and val2[i]:
ret.append(val1[i] + val2[i])
else:
ret.append(val1[i] or val2[i])
return ret


def _duration_from_time_vals(time_vals: List[_NUMERIC_INPUT_TYPE]):
ret = Duration()
ret._time_vals = time_vals
return ret
2 changes: 2 additions & 0 deletions cirq-core/cirq/value/duration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,11 @@ def test_sub():
def test_mul():
assert Duration(picos=2) * 3 == Duration(picos=6)
assert 4 * Duration(picos=3) == Duration(picos=12)
assert 0 * Duration(picos=10) == Duration()

t = sympy.Symbol('t')
assert t * Duration(picos=3) == Duration(picos=3 * t)
assert 0 * Duration(picos=t) == Duration(picos=0)

with pytest.raises(TypeError):
_ = Duration() * Duration()
Expand Down

0 comments on commit b28bfce

Please sign in to comment.