Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

time to clean up time parsing 😉 #2770

Merged
merged 4 commits into from
Dec 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions composer/callbacks/activation_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,7 @@ def __init__(self,
self.handles = []

# Parse the interval (an int is taken as a number of batches, a str as a timestring) and convert it into a Time object
if isinstance(interval, int):
self.interval = Time(interval, TimeUnit.BATCH)
elif isinstance(interval, str):
self.interval = Time.from_timestring(interval)
elif isinstance(interval, Time):
self.interval = interval
self.interval = Time.from_input(interval, TimeUnit.BATCH)

if self.interval.unit == TimeUnit.BATCH and self.interval < Time.from_timestring('10ba'):
warnings.warn(f'Currently the ActivationMonitor`s interval is set to {self.interval} '
Expand Down
7 changes: 2 additions & 5 deletions composer/callbacks/image_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class ImageVisualizer(Callback):
This callback only works with wandb logging for now.

Args:
interval (str | Time, optional): Time string specifying how often to log train images. For example, ``interval='1ep'``
interval (int | str | Time, optional): How often to log train images, given as a time string, a :class:`Time`, or an integer (interpreted as a number of batches). For example, ``interval='1ep'``
means images are logged once every epoch, while ``interval='100ba'`` means images are logged once every 100
batches. Eval images are logged once at the start of each eval. Default: ``"100ba"``.
mode (str, optional): How to log the image labels. Valid values are ``"input"`` (input only)
Expand Down Expand Up @@ -86,10 +86,7 @@ def __init__(self,
raise ValueError(f'Invalid mode: {mode}')

# Parse the interval (an int is taken as a number of batches, a str as a timestring) and convert it into a Time object
if isinstance(interval, int):
self.interval = Time(interval, TimeUnit.BATCH)
if isinstance(interval, str):
self.interval = Time.from_timestring(interval)
self.interval = Time.from_input(interval, TimeUnit.BATCH)

# Verify that the interval has supported units
if self.interval.unit not in [TimeUnit.BATCH, TimeUnit.EPOCH]:
Expand Down
93 changes: 58 additions & 35 deletions composer/core/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,20 +223,12 @@ def to_timestring(self):
"""
return str(self)

def _parse(self, other: object) -> Time:
def _parse(self, other: Union[int, float, Time, str]) -> Time:
# parse ``other`` into a Time object
if isinstance(other, Time):
return other
if isinstance(other, int):
return Time(other, self.unit)
if isinstance(other, str):
other_parsed = Time.from_timestring(other)
return other_parsed

raise TypeError(f'Cannot convert type {other} to {self.__class__.__name__}')
return Time.from_input(other, self.unit)

def _cmp(self, other: Union[int, float, Time, str]) -> int:
# When doing comparisions, and other is an integer (or float), we can safely infer
# When doing comparisons, and other is an integer (or float), we can safely infer
# the unit from self.unit
# E.g. calls like this should be allowed: if batch < 42: do_something()
# This eliminates the need to call .value everywhere
Expand Down Expand Up @@ -302,15 +294,21 @@ def __int__(self):
def __float__(self):
return float(self.value)

def __truediv__(self, other: object) -> Time[float]:
def __truediv__(self, other: Union[int, float, Time, str]) -> Time[float]:
if isinstance(other, (float, int)):
return Time(type(self.value)(self.value / other), self.unit)
other = self._parse(other)
if self.unit != other.unit:
raise RuntimeError(f'Cannot divide {self} by {other} since they have different units.')
return Time(self.value / other.value, TimeUnit.DURATION)

def __mul__(self, other: object):
def __mod__(self, other: Union[int, float, Time, str]) -> Time[TValue]:
    """Return the remainder of ``self`` divided by ``other``, expressed in the unit of ``self``.

    ``other`` is first converted with ``self._parse``. A ``RuntimeError`` is raised when the
    converted value does not share this instance's unit.
    """
    other = self._parse(other)
    units_differ = self.unit != other.unit
    if units_differ:
        raise RuntimeError(f'Cannot take mod of {self} by {other} since they have different units.')
    remainder = self.value % other.value
    return Time(remainder, self.unit)

def __mul__(self, other: Union[int, float, Time, str]):
if isinstance(other, (float, int)):
# Scale by the value.
return Time(type(self.value)(self.value * other), self.unit)
Expand All @@ -321,12 +319,43 @@ def __mul__(self, other: object):
real_type = float if real_unit == TimeUnit.DURATION else int
return Time(real_type(self.value * other.value), real_unit)

def __rmul__(self, other: object):
def __rmul__(self, other: Union[int, float, Time, str]):
return self * other

def __hash__(self):
return hash((self.value, self.unit))

@classmethod
def from_input(cls,
               i: Union[str, int, float, 'Time'],
               default_int_unit: Optional[Union[TimeUnit, str]] = None) -> Time:
    """Parse a time input into a :class:`Time` instance.

    >>> Time.from_input("5ep")
    Time(5, TimeUnit.EPOCH)
    >>> Time.from_input(5, TimeUnit.EPOCH)
    Time(5, TimeUnit.EPOCH)

    Args:
        i (str | int | float | Time): The time input. A :class:`Time` is returned unchanged, a string
            is parsed with :meth:`from_timestring`, and a number is combined with ``default_int_unit``.
        default_int_unit (TimeUnit | str, optional): The unit to use if ``i`` is an integer or a float.
            Required in that case; not used for :class:`Time` or string inputs.

    Returns:
        Time: An instance of :class:`Time`.

    Raises:
        RuntimeError: If ``i`` is a number and ``default_int_unit`` is ``None``, or if ``i`` is not
            a :class:`Time`, a string, an integer, or a float.
    """
    if isinstance(i, Time):
        return i

    if isinstance(i, str):
        return Time.from_timestring(i)

    if isinstance(i, (int, float)):
        if default_int_unit is None:
            raise RuntimeError('default_int_unit must be specified when constructing Time from an integer.')
        return Time(i, default_int_unit)

    # Report the offending type as well as the value: the message claims to name the type,
    # and the value alone (e.g. ``None``) does not identify it.
    raise RuntimeError(f'Cannot convert type {type(i).__name__} ({i!r}) to {cls.__name__}')

@classmethod
def from_timestring(cls, timestring: str) -> Time:
"""Parse a time string into a :class:`Time` instance.
Expand Down Expand Up @@ -393,39 +422,39 @@ def __init__(
epoch_wct: Optional[datetime.timedelta] = None,
batch_wct: Optional[datetime.timedelta] = None,
):
epoch = ensure_time(epoch, TimeUnit.EPOCH)
epoch = Time.from_input(epoch, TimeUnit.EPOCH)
if epoch.unit != TimeUnit.EPOCH:
raise ValueError(f'The `epoch` argument has units of {epoch.unit}; not {TimeUnit.EPOCH}.')
self._epoch = epoch

batch = ensure_time(batch, TimeUnit.BATCH)
batch = Time.from_input(batch, TimeUnit.BATCH)
if batch.unit != TimeUnit.BATCH:
raise ValueError(f'The `batch` argument has units of {batch.unit}; not {TimeUnit.BATCH}.')
self._batch = batch

sample = ensure_time(sample, TimeUnit.SAMPLE)
sample = Time.from_input(sample, TimeUnit.SAMPLE)
if sample.unit != TimeUnit.SAMPLE:
raise ValueError(f'The `sample` argument has units of {sample.unit}; not {TimeUnit.SAMPLE}.')
self._sample = sample

token = ensure_time(token, TimeUnit.TOKEN)
token = Time.from_input(token, TimeUnit.TOKEN)
if token.unit != TimeUnit.TOKEN:
raise ValueError(f'The `token` argument has units of {token.unit}; not {TimeUnit.TOKEN}.')
self._token = token

batch_in_epoch = ensure_time(batch_in_epoch, TimeUnit.BATCH)
batch_in_epoch = Time.from_input(batch_in_epoch, TimeUnit.BATCH)
if batch_in_epoch.unit != TimeUnit.BATCH:
raise ValueError((f'The `batch_in_epoch` argument has units of {batch_in_epoch.unit}; '
f'not {TimeUnit.BATCH}.'))
self._batch_in_epoch = batch_in_epoch

sample_in_epoch = ensure_time(sample_in_epoch, TimeUnit.SAMPLE)
sample_in_epoch = Time.from_input(sample_in_epoch, TimeUnit.SAMPLE)
if sample_in_epoch.unit != TimeUnit.SAMPLE:
raise ValueError((f'The `sample_in_epoch` argument has units of {sample_in_epoch.unit}; '
f'not {TimeUnit.SAMPLE}.'))
self._sample_in_epoch = sample_in_epoch

token_in_epoch = ensure_time(token_in_epoch, TimeUnit.TOKEN)
token_in_epoch = Time.from_input(token_in_epoch, TimeUnit.TOKEN)
if token_in_epoch.unit != TimeUnit.TOKEN:
raise ValueError((f'The `token_in_epoch` argument has units of {token_in_epoch.unit}; '
f'not {TimeUnit.TOKEN}.'))
Expand Down Expand Up @@ -563,7 +592,7 @@ def get(self, unit: Union[str, TimeUnit]) -> Time[int]:
return self.token
raise ValueError(f'Invalid unit: {unit}')

def _parse(self, other: object) -> Time:
def _parse(self, other: Union[int, float, Time, str]) -> Time:
# parse ``other`` into a Time object
if isinstance(other, Time):
return other
Expand All @@ -573,7 +602,7 @@ def _parse(self, other: object) -> Time:

raise TypeError(f'Cannot convert type {other} to {self.__class__.__name__}')

def __eq__(self, other: object):
def __eq__(self, other: Union[int, float, Time, str]):
if not isinstance(other, (Time, Timestamp, str)):
return NotImplemented
if isinstance(other, Timestamp):
Expand All @@ -582,7 +611,7 @@ def __eq__(self, other: object):
self_counter = self.get(other.unit)
return self_counter == other

def __ne__(self, other: object):
def __ne__(self, other: Union[int, float, Time, str]):
if not isinstance(other, (Time, Timestamp, str)):
return NotImplemented
if isinstance(other, Timestamp):
Expand All @@ -591,28 +620,28 @@ def __ne__(self, other: object):
self_counter = self.get(other.unit)
return self_counter != other

def __lt__(self, other: object):
def __lt__(self, other: Union[int, float, Time, str]):
if not isinstance(other, (Time, str)):
return NotImplemented
other = self._parse(other)
self_counter = self.get(other.unit)
return self_counter < other

def __le__(self, other: object):
def __le__(self, other: Union[int, float, Time, str]):
if not isinstance(other, (Time, str)):
return NotImplemented
other = self._parse(other)
self_counter = self.get(other.unit)
return self_counter <= other

def __gt__(self, other: object):
def __gt__(self, other: Union[int, float, Time, str]):
if not isinstance(other, (Time, str)):
return NotImplemented
other = self._parse(other)
self_counter = self.get(other.unit)
return self_counter > other

def __ge__(self, other: object):
def __ge__(self, other: Union[int, float, Time, str]):
if not isinstance(other, (Time, str)):
return NotImplemented
other = self._parse(other)
Expand Down Expand Up @@ -783,10 +812,4 @@ def ensure_time(maybe_time: Union[Time, str, int], int_unit: Union[TimeUnit, str
Returns:
Time: An instance of :class:`.Time`.
"""
if isinstance(maybe_time, str):
return Time.from_timestring(maybe_time)
if isinstance(maybe_time, int):
return Time(maybe_time, int_unit)
if isinstance(maybe_time, Time):
return maybe_time
raise TypeError(f'Unsupported type for ensure_time: {type(maybe_time)}')
return Time.from_input(maybe_time, int_unit)
6 changes: 1 addition & 5 deletions composer/loggers/console_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,7 @@ def __init__(self,
stream: Union[str, TextIO] = sys.stderr,
log_traces: bool = False) -> None:

if isinstance(log_interval, int):
log_interval = Time(log_interval, TimeUnit.EPOCH)
if isinstance(log_interval, str):
log_interval = Time.from_timestring(log_interval)

log_interval = Time.from_input(log_interval, TimeUnit.EPOCH)
self.last_logged_batch = 0

if log_interval.unit not in (TimeUnit.EPOCH, TimeUnit.BATCH):
Expand Down
5 changes: 1 addition & 4 deletions composer/loggers/slack_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,7 @@ def __init__(
# Create a regex of all keys to include
self.regex_all_keys = '(' + ')|('.join(include_keys) + ')'

if isinstance(log_interval, int):
self.log_interval = Time(log_interval, TimeUnit.EPOCH)
if isinstance(log_interval, str):
self.log_interval = Time.from_timestring(log_interval)
self.log_interval = Time.from_input(log_interval, TimeUnit.EPOCH)
if self.log_interval.unit not in (TimeUnit.EPOCH, TimeUnit.BATCH):
raise ValueError('The `slack logger log_interval` argument must have units of EPOCH or BATCH.')

Expand Down
5 changes: 5 additions & 0 deletions composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2017,6 +2017,11 @@ def _train_loop(self) -> None:
finished_epoch_early = False
last_wct = datetime.datetime.now()

if self.state.max_duration is None:
# This is essentially just a type check, as max_duration should always be
# asserted to be not None when Trainer.fit() is called
raise RuntimeError('max_duration must be specified when initializing the Trainer')

while self.state.timestamp < self.state.max_duration:
if int(self.state.timestamp.batch_in_epoch) == 0:
self.engine.run_event(Event.EPOCH_START)
Expand Down
6 changes: 1 addition & 5 deletions composer/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,7 @@ def create_interval_scheduler(interval: Union[str, int, 'Time'],
if final_events is None:
final_events = {Event.BATCH_CHECKPOINT, Event.EPOCH_CHECKPOINT}

if isinstance(interval, str):
interval = Time.from_timestring(interval)
if isinstance(interval, int):
interval = Time(interval, TimeUnit.EPOCH)

interval = Time.from_input(interval, TimeUnit.EPOCH)
if interval.unit == TimeUnit.EPOCH:
interval_event = Event.EPOCH_CHECKPOINT if checkpoint_events else Event.EPOCH_END
elif interval.unit in {TimeUnit.BATCH, TimeUnit.TOKEN, TimeUnit.SAMPLE, TimeUnit.DURATION}:
Expand Down
58 changes: 58 additions & 0 deletions tests/test_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,64 @@ def test_time_math():
assert t4 * 2 == Time.from_timestring('1dur')
assert t1 / t2 == t4
assert t2 / 2 == t1
assert t3 % t3 == Time.from_timestring('0ep')
assert t3 % t2 == t1


def test_invalid_math():
    """Every binary operation between times with different units must raise ``RuntimeError``."""
    epoch_time = Time.from_timestring('1ep')
    batch_time = Time.from_timestring('1ba')

    # Comparisons first, then arithmetic.
    mixed_unit_operations = (
        lambda a, b: a > b,
        lambda a, b: a < b,
        lambda a, b: a >= b,
        lambda a, b: a <= b,
        lambda a, b: a == b,
        lambda a, b: a != b,
        lambda a, b: a + b,
        lambda a, b: a - b,
        lambda a, b: a / b,
        lambda a, b: a % b,
        lambda a, b: a * b,
    )

    for operation in mixed_unit_operations:
        with pytest.raises(RuntimeError):
            _ = operation(epoch_time, batch_time)


def test_time_from_input():
    """``Time.from_input`` accepts a Time, a timestring, or a number plus a unit, and rejects the rest."""
    expected = Time(1, TimeUnit.EPOCH)

    # Each tuple holds the positional arguments of one accepted call.
    accepted_calls = (
        (expected,),
        ('1ep',),
        (1, TimeUnit.EPOCH),
        (1, 'ep'),
    )
    for call_args in accepted_calls:
        assert Time.from_input(*call_args) == expected

    # ``None`` and a list are unsupported types; a bare integer has no unit to fall back on.
    rejected_calls = (
        (None,),
        ([123],),
        (1,),
    )
    for call_args in rejected_calls:
        with pytest.raises(RuntimeError):
            Time.from_input(*call_args)  # type: ignore


def test_time_repr():
Expand Down
Loading