Skip to content

Commit

Permalink
time to clean up time parsing 😉 (#2770)
Browse files Browse the repository at this point in the history
* time to clean up time parsing

* fix type error

* updates
  • Loading branch information
aspfohl authored Dec 9, 2023
1 parent cb8f937 commit f097fd7
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 60 deletions.
7 changes: 1 addition & 6 deletions composer/callbacks/activation_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,7 @@ def __init__(self,
self.handles = []

# Check that the interval timestring is parsable and convert into time object
if isinstance(interval, int):
self.interval = Time(interval, TimeUnit.BATCH)
elif isinstance(interval, str):
self.interval = Time.from_timestring(interval)
elif isinstance(interval, Time):
self.interval = interval
self.interval = Time.from_input(interval, TimeUnit.BATCH)

if self.interval.unit == TimeUnit.BATCH and self.interval < Time.from_timestring('10ba'):
warnings.warn(f'Currently the ActivationMonitor`s interval is set to {self.interval} '
Expand Down
7 changes: 2 additions & 5 deletions composer/callbacks/image_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class ImageVisualizer(Callback):
This callback only works with wandb logging for now.
Args:
interval (str | Time, optional): Time string specifying how often to log train images. For example, ``interval='1ep'``
interval (int | str | Time, optional): Time string specifying how often to log train images. For example, ``interval='1ep'``
means images are logged once every epoch, while ``interval='100ba'`` means images are logged once every 100
batches. Eval images are logged once at the start of each eval. Default: ``"100ba"``.
mode (str, optional): How to log the image labels. Valid values are ``"input"`` (input only)
Expand Down Expand Up @@ -86,10 +86,7 @@ def __init__(self,
raise ValueError(f'Invalid mode: {mode}')

# Check that the interval timestring is parsable and convert into time object
if isinstance(interval, int):
self.interval = Time(interval, TimeUnit.BATCH)
if isinstance(interval, str):
self.interval = Time.from_timestring(interval)
self.interval = Time.from_input(interval, TimeUnit.BATCH)

# Verify that the interval has supported units
if self.interval.unit not in [TimeUnit.BATCH, TimeUnit.EPOCH]:
Expand Down
93 changes: 58 additions & 35 deletions composer/core/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,20 +223,12 @@ def to_timestring(self):
"""
return str(self)

def _parse(self, other: object) -> Time:
def _parse(self, other: Union[int, float, Time, str]) -> Time:
# parse ``other`` into a Time object
if isinstance(other, Time):
return other
if isinstance(other, int):
return Time(other, self.unit)
if isinstance(other, str):
other_parsed = Time.from_timestring(other)
return other_parsed

raise TypeError(f'Cannot convert type {other} to {self.__class__.__name__}')
return Time.from_input(other, self.unit)

def _cmp(self, other: Union[int, float, Time, str]) -> int:
# When doing comparisions, and other is an integer (or float), we can safely infer
# When doing comparisons, and other is an integer (or float), we can safely infer
# the unit from self.unit
# E.g. calls like this should be allowed: if batch < 42: do_something()
# This eliminates the need to call .value everywhere
Expand Down Expand Up @@ -302,15 +294,21 @@ def __int__(self):
def __float__(self):
return float(self.value)

def __truediv__(self, other: object) -> Time[float]:
def __truediv__(self, other: Union[int, float, Time, str]) -> Time[float]:
if isinstance(other, (float, int)):
return Time(type(self.value)(self.value / other), self.unit)
other = self._parse(other)
if self.unit != other.unit:
raise RuntimeError(f'Cannot divide {self} by {other} since they have different units.')
return Time(self.value / other.value, TimeUnit.DURATION)

def __mul__(self, other: object):
def __mod__(self, other: Union[int, float, Time, str]) -> Time[TValue]:
other = self._parse(other)
if self.unit != other.unit:
raise RuntimeError(f'Cannot take mod of {self} by {other} since they have different units.')
return Time(self.value % other.value, self.unit)

def __mul__(self, other: Union[int, float, Time, str]):
if isinstance(other, (float, int)):
# Scale by the value.
return Time(type(self.value)(self.value * other), self.unit)
Expand All @@ -321,12 +319,43 @@ def __mul__(self, other: object):
real_type = float if real_unit == TimeUnit.DURATION else int
return Time(real_type(self.value * other.value), real_unit)

def __rmul__(self, other: object):
def __rmul__(self, other: Union[int, float, Time, str]):
return self * other

def __hash__(self):
return hash((self.value, self.unit))

@classmethod
def from_input(cls,
i: Union[str, int, float, 'Time'],
default_int_unit: Optional[Union[TimeUnit, str]] = None) -> Time:
"""Parse a time input into a :class:`Time` instance.
Args:
i (str | int | Time): The time input.
default_int_unit (TimeUnit, optional): The default unit to use if ``i`` is an integer
>>> Time.from_input("5ep")
Time(5, TimeUnit.EPOCH)
>>> Time.from_input(5, TimeUnit.EPOCH)
Time(5, TimeUnit.EPOCH)
Returns:
Time: An instance of :class:`Time`.
"""
if isinstance(i, Time):
return i

if isinstance(i, str):
return Time.from_timestring(i)

if isinstance(i, int) or isinstance(i, float):
if default_int_unit is None:
raise RuntimeError('default_int_unit must be specified when constructing Time from an integer.')
return Time(i, default_int_unit)

raise RuntimeError(f'Cannot convert type {i} to {cls.__name__}')

@classmethod
def from_timestring(cls, timestring: str) -> Time:
"""Parse a time string into a :class:`Time` instance.
Expand Down Expand Up @@ -393,39 +422,39 @@ def __init__(
epoch_wct: Optional[datetime.timedelta] = None,
batch_wct: Optional[datetime.timedelta] = None,
):
epoch = ensure_time(epoch, TimeUnit.EPOCH)
epoch = Time.from_input(epoch, TimeUnit.EPOCH)
if epoch.unit != TimeUnit.EPOCH:
raise ValueError(f'The `epoch` argument has units of {epoch.unit}; not {TimeUnit.EPOCH}.')
self._epoch = epoch

batch = ensure_time(batch, TimeUnit.BATCH)
batch = Time.from_input(batch, TimeUnit.BATCH)
if batch.unit != TimeUnit.BATCH:
raise ValueError(f'The `batch` argument has units of {batch.unit}; not {TimeUnit.BATCH}.')
self._batch = batch

sample = ensure_time(sample, TimeUnit.SAMPLE)
sample = Time.from_input(sample, TimeUnit.SAMPLE)
if sample.unit != TimeUnit.SAMPLE:
raise ValueError(f'The `sample` argument has units of {sample.unit}; not {TimeUnit.SAMPLE}.')
self._sample = sample

token = ensure_time(token, TimeUnit.TOKEN)
token = Time.from_input(token, TimeUnit.TOKEN)
if token.unit != TimeUnit.TOKEN:
raise ValueError(f'The `token` argument has units of {token.unit}; not {TimeUnit.TOKEN}.')
self._token = token

batch_in_epoch = ensure_time(batch_in_epoch, TimeUnit.BATCH)
batch_in_epoch = Time.from_input(batch_in_epoch, TimeUnit.BATCH)
if batch_in_epoch.unit != TimeUnit.BATCH:
raise ValueError((f'The `batch_in_epoch` argument has units of {batch_in_epoch.unit}; '
f'not {TimeUnit.BATCH}.'))
self._batch_in_epoch = batch_in_epoch

sample_in_epoch = ensure_time(sample_in_epoch, TimeUnit.SAMPLE)
sample_in_epoch = Time.from_input(sample_in_epoch, TimeUnit.SAMPLE)
if sample_in_epoch.unit != TimeUnit.SAMPLE:
raise ValueError((f'The `sample_in_epoch` argument has units of {sample_in_epoch.unit}; '
f'not {TimeUnit.SAMPLE}.'))
self._sample_in_epoch = sample_in_epoch

token_in_epoch = ensure_time(token_in_epoch, TimeUnit.TOKEN)
token_in_epoch = Time.from_input(token_in_epoch, TimeUnit.TOKEN)
if token_in_epoch.unit != TimeUnit.TOKEN:
raise ValueError((f'The `token_in_epoch` argument has units of {token_in_epoch.unit}; '
f'not {TimeUnit.TOKEN}.'))
Expand Down Expand Up @@ -563,7 +592,7 @@ def get(self, unit: Union[str, TimeUnit]) -> Time[int]:
return self.token
raise ValueError(f'Invalid unit: {unit}')

def _parse(self, other: object) -> Time:
def _parse(self, other: Union[int, float, Time, str]) -> Time:
# parse ``other`` into a Time object
if isinstance(other, Time):
return other
Expand All @@ -573,7 +602,7 @@ def _parse(self, other: object) -> Time:

raise TypeError(f'Cannot convert type {other} to {self.__class__.__name__}')

def __eq__(self, other: object):
def __eq__(self, other: Union[int, float, Time, str]):
if not isinstance(other, (Time, Timestamp, str)):
return NotImplemented
if isinstance(other, Timestamp):
Expand All @@ -582,7 +611,7 @@ def __eq__(self, other: object):
self_counter = self.get(other.unit)
return self_counter == other

def __ne__(self, other: object):
def __ne__(self, other: Union[int, float, Time, str]):
if not isinstance(other, (Time, Timestamp, str)):
return NotImplemented
if isinstance(other, Timestamp):
Expand All @@ -591,28 +620,28 @@ def __ne__(self, other: object):
self_counter = self.get(other.unit)
return self_counter != other

def __lt__(self, other: object):
def __lt__(self, other: Union[int, float, Time, str]):
if not isinstance(other, (Time, str)):
return NotImplemented
other = self._parse(other)
self_counter = self.get(other.unit)
return self_counter < other

def __le__(self, other: object):
def __le__(self, other: Union[int, float, Time, str]):
if not isinstance(other, (Time, str)):
return NotImplemented
other = self._parse(other)
self_counter = self.get(other.unit)
return self_counter <= other

def __gt__(self, other: object):
def __gt__(self, other: Union[int, float, Time, str]):
if not isinstance(other, (Time, str)):
return NotImplemented
other = self._parse(other)
self_counter = self.get(other.unit)
return self_counter > other

def __ge__(self, other: object):
def __ge__(self, other: Union[int, float, Time, str]):
if not isinstance(other, (Time, str)):
return NotImplemented
other = self._parse(other)
Expand Down Expand Up @@ -783,10 +812,4 @@ def ensure_time(maybe_time: Union[Time, str, int], int_unit: Union[TimeUnit, str
Returns:
Time: An instance of :class:`.Time`.
"""
if isinstance(maybe_time, str):
return Time.from_timestring(maybe_time)
if isinstance(maybe_time, int):
return Time(maybe_time, int_unit)
if isinstance(maybe_time, Time):
return maybe_time
raise TypeError(f'Unsupported type for ensure_time: {type(maybe_time)}')
return Time.from_input(maybe_time, int_unit)
6 changes: 1 addition & 5 deletions composer/loggers/console_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,7 @@ def __init__(self,
stream: Union[str, TextIO] = sys.stderr,
log_traces: bool = False) -> None:

if isinstance(log_interval, int):
log_interval = Time(log_interval, TimeUnit.EPOCH)
if isinstance(log_interval, str):
log_interval = Time.from_timestring(log_interval)

log_interval = Time.from_input(log_interval, TimeUnit.EPOCH)
self.last_logged_batch = 0

if log_interval.unit not in (TimeUnit.EPOCH, TimeUnit.BATCH):
Expand Down
5 changes: 1 addition & 4 deletions composer/loggers/slack_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,7 @@ def __init__(
# Create a regex of all keys to include
self.regex_all_keys = '(' + ')|('.join(include_keys) + ')'

if isinstance(log_interval, int):
self.log_interval = Time(log_interval, TimeUnit.EPOCH)
if isinstance(log_interval, str):
self.log_interval = Time.from_timestring(log_interval)
self.log_interval = Time.from_input(log_interval, TimeUnit.EPOCH)
if self.log_interval.unit not in (TimeUnit.EPOCH, TimeUnit.BATCH):
raise ValueError('The `slack logger log_interval` argument must have units of EPOCH or BATCH.')

Expand Down
5 changes: 5 additions & 0 deletions composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2017,6 +2017,11 @@ def _train_loop(self) -> None:
finished_epoch_early = False
last_wct = datetime.datetime.now()

if self.state.max_duration is None:
# This is essentially just a type check, as max_duration should always be
# asserted to be not None when Trainer.fit() is called
raise RuntimeError('max_duration must be specified when initializing the Trainer')

while self.state.timestamp < self.state.max_duration:
if int(self.state.timestamp.batch_in_epoch) == 0:
self.engine.run_event(Event.EPOCH_START)
Expand Down
6 changes: 1 addition & 5 deletions composer/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,7 @@ def create_interval_scheduler(interval: Union[str, int, 'Time'],
if final_events is None:
final_events = {Event.BATCH_CHECKPOINT, Event.EPOCH_CHECKPOINT}

if isinstance(interval, str):
interval = Time.from_timestring(interval)
if isinstance(interval, int):
interval = Time(interval, TimeUnit.EPOCH)

interval = Time.from_input(interval, TimeUnit.EPOCH)
if interval.unit == TimeUnit.EPOCH:
interval_event = Event.EPOCH_CHECKPOINT if checkpoint_events else Event.EPOCH_END
elif interval.unit in {TimeUnit.BATCH, TimeUnit.TOKEN, TimeUnit.SAMPLE, TimeUnit.DURATION}:
Expand Down
58 changes: 58 additions & 0 deletions tests/test_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,64 @@ def test_time_math():
assert t4 * 2 == Time.from_timestring('1dur')
assert t1 / t2 == t4
assert t2 / 2 == t1
assert t3 % t3 == Time.from_timestring('0ep')
assert t3 % t2 == t1


def test_invalid_math():
t1 = Time.from_timestring('1ep')
t2 = Time.from_timestring('1ba')

with pytest.raises(RuntimeError):
_ = t1 > t2

with pytest.raises(RuntimeError):
_ = t1 < t2

with pytest.raises(RuntimeError):
_ = t1 >= t2

with pytest.raises(RuntimeError):
_ = t1 <= t2

with pytest.raises(RuntimeError):
_ = t1 == t2

with pytest.raises(RuntimeError):
_ = t1 != t2

with pytest.raises(RuntimeError):
_ = t1 + t2

with pytest.raises(RuntimeError):
_ = t1 - t2

with pytest.raises(RuntimeError):
_ = t1 / t2

with pytest.raises(RuntimeError):
_ = t1 % t2

with pytest.raises(RuntimeError):
_ = t1 * t2


def test_time_from_input():
expected = Time(1, TimeUnit.EPOCH)

assert Time.from_input(expected) == expected
assert Time.from_input('1ep') == expected
assert Time.from_input(1, TimeUnit.EPOCH) == expected
assert Time.from_input(1, 'ep') == expected

with pytest.raises(RuntimeError):
Time.from_input(None) # type: ignore

with pytest.raises(RuntimeError):
Time.from_input([123]) # type: ignore

with pytest.raises(RuntimeError):
Time.from_input(1)


def test_time_repr():
Expand Down

0 comments on commit f097fd7

Please sign in to comment.