remove exception (#2759)
mvpatel2000 authored Dec 8, 2023
1 parent 7f55b7a commit cb8f937
Showing 3 changed files with 131 additions and 144 deletions.
3 changes: 1 addition & 2 deletions composer/core/__init__.py
@@ -19,7 +19,7 @@
from composer.core.serializable import Serializable
from composer.core.state import State
from composer.core.time import Time, Timestamp, TimeUnit, ensure_time
from composer.core.types import JSON, Batch, BreakEpochException, Dataset, MemoryFormat, PyTorchScheduler, TrainerMode
from composer.core.types import JSON, Batch, Dataset, MemoryFormat, PyTorchScheduler, TrainerMode

__all__ = [
'Algorithm',
@@ -46,6 +46,5 @@
'JSON',
'MemoryFormat',
'TrainerMode',
'BreakEpochException',
'validate_eval_automicrobatching',
]
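
Both hunks above drop BreakEpochException from the package-root import and from __all__, so the name is no longer re-exported by composer.core. A minimal sketch of the downstream effect after upgrading; the ImportError fallback is illustrative and not part of the commit:

# Illustrative only: probing for the symbol removed by this commit.
try:
    from composer.core import BreakEpochException  # resolved before cb8f937
except ImportError:
    BreakEpochException = None  # no longer exported; downstream callers must migrate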
10 changes: 1 addition & 9 deletions composer/core/types.py
@@ -21,7 +21,7 @@

from composer.utils import StringEnum

__all__ = ['Batch', 'PyTorchScheduler', 'JSON', 'MemoryFormat', 'TrainerMode', 'BreakEpochException']
__all__ = ['Batch', 'PyTorchScheduler', 'JSON', 'MemoryFormat', 'TrainerMode']

Batch = Any

@@ -37,14 +37,6 @@
JSON = Union[str, float, int, None, List['JSON'], Dict[str, 'JSON']]


class BreakEpochException(Exception):
"""Raising this exception will immediately end the current epoch.
If you're wondering whether you should use this, the answer is no.
"""
pass


class TrainerMode(StringEnum):
"""Enum to represent which mode the Trainer is in.
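
The deleted BreakEpochException class above was the one (explicitly discouraged) hook for ending an epoch early by raising from inside training code; the matching except clause is removed from the trainer loop in the next file, so such an exception would now propagate and abort the run. Below is a hedged sketch of what that pre-commit usage looked like; the callback and the 100-batch threshold are illustrative assumptions, not code from this repository.

# Pre-cb8f937 sketch (illustrative): a callback that cut the current epoch short.
from composer.core import Callback, State
from composer.core.types import BreakEpochException  # exists only before this commit
from composer.loggers import Logger

class StopEpochEarly(Callback):
    """Hypothetical callback: end the current epoch after 100 batches."""

    def batch_end(self, state: State, logger: Logger) -> None:
        if int(state.timestamp.batch_in_epoch) >= 100:
            # Previously swallowed by the trainer's except clause (removed below);
            # after this commit it would propagate and abort training instead.
            raise BreakEpochException

After this commit there is no direct equivalent; the closest mechanism visible in the loop below is max_duration, which can stop mid-epoch (see finished_epoch_early) but ends the whole run rather than a single epoch.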
262 changes: 129 additions & 133 deletions composer/trainer/trainer.py
@@ -35,10 +35,9 @@
from torchmetrics import Metric

from composer.callbacks import CheckpointSaver, OptimizerMonitor
from composer.core import (Algorithm, AlgorithmPass, Batch, BreakEpochException, Callback, DataSpec, Engine, Evaluator,
Event, Precision, PyTorchScheduler, State, Time, Timestamp, TimeUnit, TrainerMode,
ensure_data_spec, ensure_evaluator, ensure_time, get_precision_context,
validate_eval_automicrobatching)
from composer.core import (Algorithm, AlgorithmPass, Batch, Callback, DataSpec, Engine, Evaluator, Event, Precision,
PyTorchScheduler, State, Time, Timestamp, TimeUnit, TrainerMode, ensure_data_spec,
ensure_evaluator, ensure_time, get_precision_context, validate_eval_automicrobatching)
from composer.devices import Device, DeviceCPU, DeviceGPU, DeviceMPS, DeviceTPU
from composer.loggers import (ConsoleLogger, Logger, LoggerDestination, MosaicMLLogger, ProgressBarLogger,
RemoteUploaderDownloader, WandBLogger)
@@ -2019,143 +2018,140 @@ def _train_loop(self) -> None:
last_wct = datetime.datetime.now()

while self.state.timestamp < self.state.max_duration:
try:
if int(self.state.timestamp.batch_in_epoch) == 0:
self.engine.run_event(Event.EPOCH_START)
self.logger.log_metrics({'time/epoch': self.state.timestamp.epoch.value})
if int(self.state.timestamp.batch_in_epoch) == 0:
self.engine.run_event(Event.EPOCH_START)
self.logger.log_metrics({'time/epoch': self.state.timestamp.epoch.value})

dataloader = self.state.dataloader
if isinstance(dataloader, DataLoader) and isinstance(dataloader.sampler, DistributedSampler):
dataloader.sampler.set_epoch(int(self.state.timestamp.epoch))

for batch_idx, self.state.batch in enumerate(self._iter_dataloader(TrainerMode.TRAIN)):
# Spin dataloader forward unless dataloader handles internally with dataset_resumption
if self.spin_dataloaders and 'train' not in self.state.dataset_resumption and batch_idx < int(
self.state.timestamp.batch_in_epoch):
# Restore the RNG state immediately before the next batch is yielded from the dataloader
if batch_idx + 1 == int(self.state.timestamp.batch_in_epoch) and self._rng_state is not None:
reproducibility.load_rng_state(self._rng_state)
self._rng_state = None
continue

self.state.batch = self.state.device.batch_to_device(self.state.batch)
self.state.batch = self._train_data_spec.device_transforms(self.state.batch)
rank_num_samples = self._train_data_spec.get_num_samples_in_batch(self.state.batch)
rank_num_tokens = self._train_data_spec.get_num_tokens_in_batch(self.state.batch)

if self.state.deepspeed_enabled:
self.state.batch = _fix_batch_precision_for_deepspeed(self.state.batch, self.state.precision)

self.engine.run_event(Event.AFTER_DATALOADER)

self.engine.run_event(Event.BATCH_START)

# Log time values
self.logger.log_metrics({
'time/batch': self.state.timestamp.batch.value,
'time/sample': self.state.timestamp.sample.value,
'time/batch_in_epoch': self.state.timestamp.batch_in_epoch.value,
'time/sample_in_epoch': self.state.timestamp.sample_in_epoch.value,
})
if rank_num_tokens > 0:
self.logger.log_metrics({'time/token': self.state.timestamp.token.value})
self.logger.log_metrics({'time/token_in_epoch': self.state.timestamp.token_in_epoch.value})

total_loss_dict = self._train_batch(use_grad_scaling)

if use_grad_scaling:
self.state.scaler.update()

# total_loss_dict can be None if gradient scaling failed
if total_loss_dict is not None:
map_collection(total_loss_dict, dist.all_reduce)
total_loss_dict = {
k: loss.cpu().item() / dist.get_world_size() for k, loss in total_loss_dict.items()
}
self.state.total_loss_dict = total_loss_dict
self.logger.log_metrics(total_loss_dict)

# The scheduler step.step() and compute_and_log_metrics() are going to be included in the
# next batch's wall clock time. The time accumulation must be done here so schedulers
# have the latest timing information

now = datetime.datetime.now()

batch_time = now - last_wct

total_num_samples, total_num_tokens, batch_time = self._accumulate_time_across_ranks(
rank_num_samples,
rank_num_tokens,
batch_time,
dataloader = self.state.dataloader
if isinstance(dataloader, DataLoader) and isinstance(dataloader.sampler, DistributedSampler):
dataloader.sampler.set_epoch(int(self.state.timestamp.epoch))

for batch_idx, self.state.batch in enumerate(self._iter_dataloader(TrainerMode.TRAIN)):
# Spin dataloader forward unless dataloader handles internally with dataset_resumption
if self.spin_dataloaders and 'train' not in self.state.dataset_resumption and batch_idx < int(
self.state.timestamp.batch_in_epoch):
# Restore the RNG state immediately before the next batch is yielded from the dataloader
if batch_idx + 1 == int(self.state.timestamp.batch_in_epoch) and self._rng_state is not None:
reproducibility.load_rng_state(self._rng_state)
self._rng_state = None
continue

self.state.batch = self.state.device.batch_to_device(self.state.batch)
self.state.batch = self._train_data_spec.device_transforms(self.state.batch)
rank_num_samples = self._train_data_spec.get_num_samples_in_batch(self.state.batch)
rank_num_tokens = self._train_data_spec.get_num_tokens_in_batch(self.state.batch)

if self.state.deepspeed_enabled:
self.state.batch = _fix_batch_precision_for_deepspeed(self.state.batch, self.state.precision)

self.engine.run_event(Event.AFTER_DATALOADER)

self.engine.run_event(Event.BATCH_START)

# Log time values
self.logger.log_metrics({
'time/batch': self.state.timestamp.batch.value,
'time/sample': self.state.timestamp.sample.value,
'time/batch_in_epoch': self.state.timestamp.batch_in_epoch.value,
'time/sample_in_epoch': self.state.timestamp.sample_in_epoch.value,
})
if rank_num_tokens > 0:
self.logger.log_metrics({'time/token': self.state.timestamp.token.value})
self.logger.log_metrics({'time/token_in_epoch': self.state.timestamp.token_in_epoch.value})

total_loss_dict = self._train_batch(use_grad_scaling)

if use_grad_scaling:
self.state.scaler.update()

# total_loss_dict can be None if gradient scaling failed
if total_loss_dict is not None:
map_collection(total_loss_dict, dist.all_reduce)
total_loss_dict = {
k: loss.cpu().item() / dist.get_world_size() for k, loss in total_loss_dict.items()
}
self.state.total_loss_dict = total_loss_dict
self.logger.log_metrics(total_loss_dict)

# The scheduler step.step() and compute_and_log_metrics() are going to be included in the
# next batch's wall clock time. The time accumulation must be done here so schedulers
# have the latest timing information

now = datetime.datetime.now()

batch_time = now - last_wct

total_num_samples, total_num_tokens, batch_time = self._accumulate_time_across_ranks(
rank_num_samples,
rank_num_tokens,
batch_time,
)

# `now` is actually in the past, but want to include the time it takes to perform this reduction
last_wct = now

if self._scheduler_step_frequency == TimeUnit.BATCH:
for scheduler in self.state.schedulers:
scheduler.step()

if self.state.train_metrics is not None:
self._compute_and_log_metrics(
dataloader_label='train',
metrics=self.state.train_metrics,
)

# `now` is actually in the past, but want to include the time it takes to perform this reduction
last_wct = now
self.state.previous_timestamp = self.state.timestamp
self.state.timestamp = self.state.timestamp.to_next_batch(
samples=total_num_samples,
tokens=total_num_tokens,
duration=batch_time,
)

self.engine.run_event(Event.BATCH_END)

if self._scheduler_step_frequency == TimeUnit.BATCH:
for scheduler in self.state.schedulers:
scheduler.step()
# Pause the timing during evaluation
# Evaluation time is tracked separately in state.eval_timestamp
duration = datetime.datetime.now() - last_wct
self._run_evaluators(Event.BATCH_END)
last_wct = datetime.datetime.now() - duration

if self.state.train_metrics is not None:
self._compute_and_log_metrics(
dataloader_label='train',
metrics=self.state.train_metrics,
)
self.engine.run_event(Event.BATCH_CHECKPOINT)

if self.state.timestamp >= self.state.max_duration:
# If max_duration is specified in batches, samples, or tokens, and
# the max_duration is reached mid-epoch, then break out of the dataloader
# to finish the epoch early and finish training.
finished_epoch_early = True
break

self.state.previous_timestamp = self.state.timestamp
self.state.timestamp = self.state.timestamp.to_next_batch(
samples=total_num_samples,
tokens=total_num_tokens,
duration=batch_time,
if not finished_epoch_early or self.state.dataloader_len == self.state.timestamp.batch_in_epoch:
# Trigger the epoch end events if the dataloader was exhausted.
# This happens if the "break" did not trigger above, or if it
# did (e.g. duration specified in samples/batches/tokens), but it is still
# the end of the dataloader (i.e. next(dataloader) would raise StopIteration)
if self.state.train_metrics is not None:
self.state.train_metrics = self._ensure_metrics_device_and_dtype(self.state.train_metrics)
self._compute_and_log_metrics(
dataloader_label='train',
metrics=self.state.train_metrics,
)

self.engine.run_event(Event.BATCH_END)

# Pause the timing during evaluation
# Evaluation time is tracked separately in state.eval_timestamp
duration = datetime.datetime.now() - last_wct
self._run_evaluators(Event.BATCH_END)
last_wct = datetime.datetime.now() - duration

self.engine.run_event(Event.BATCH_CHECKPOINT)

if self.state.timestamp >= self.state.max_duration:
# If max_duration is specified in batches, samples, or tokens, and
# the max_duration is reached mid-epoch, then break out of the dataloader
# to finish the epoch early and finish training.
finished_epoch_early = True
break

if not finished_epoch_early or self.state.dataloader_len == self.state.timestamp.batch_in_epoch:
# Trigger the epoch end events if the dataloader was exhausted.
# This happens if the "break" did not trigger above, or if it
# did (e.g. duration specified in samples/batches/tokens), but it is still
# the end of the dataloader (i.e. next(dataloader) would raise StopIteration)
if self.state.train_metrics is not None:
self.state.train_metrics = self._ensure_metrics_device_and_dtype(self.state.train_metrics)
self._compute_and_log_metrics(
dataloader_label='train',
metrics=self.state.train_metrics,
)

if self._scheduler_step_frequency == TimeUnit.EPOCH:
for scheduler in self.state.schedulers:
scheduler.step()

self.state.previous_timestamp = self.state.timestamp
self.state.timestamp = self.state.timestamp.to_next_epoch()

self.engine.run_event(Event.EPOCH_END)

# Pause the timing during evaluation
# Evaluation time is tracked separately in state.eval_timestamp
duration = datetime.datetime.now() - last_wct
self._run_evaluators(Event.EPOCH_END)
last_wct = datetime.datetime.now() - duration

self.engine.run_event(Event.EPOCH_CHECKPOINT)
except BreakEpochException:
log.info(f'Skipping the rest of Epoch {int(self.state.timestamp.epoch)}')
if self._scheduler_step_frequency == TimeUnit.EPOCH:
for scheduler in self.state.schedulers:
scheduler.step()

self.state.previous_timestamp = self.state.timestamp
self.state.timestamp = self.state.timestamp.to_next_epoch()

self.engine.run_event(Event.EPOCH_END)

# Pause the timing during evaluation
# Evaluation time is tracked separately in state.eval_timestamp
duration = datetime.datetime.now() - last_wct
self._run_evaluators(Event.EPOCH_END)
last_wct = datetime.datetime.now() - duration

self.engine.run_event(Event.EPOCH_CHECKPOINT)

# Log final time values
self.logger.log_metrics({

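
Beyond deleting the try/except, the loop body in the hunk above is functionally unchanged, but the wall-clock bookkeeping around self._run_evaluators is worth isolating: evaluation time is excluded from the training wall-clock measurement by pushing last_wct forward by the evaluation's duration. A self-contained sketch of that pattern follows; train_one_batch and run_evaluation are placeholder stand-ins, not trainer APIs.

import datetime
import time

def train_one_batch() -> None:
    time.sleep(0.1)  # stand-in for the per-batch training work

def run_evaluation() -> None:
    time.sleep(0.2)  # stand-in for self._run_evaluators(Event.BATCH_END)

last_wct = datetime.datetime.now()
for _ in range(3):
    train_one_batch()
    now = datetime.datetime.now()
    batch_time = now - last_wct  # wall clock charged to this batch
    last_wct = now

    # Pause the clock while evaluators run, mirroring the trainer loop:
    # the evaluation's duration is excised by shifting last_wct forward.
    duration = datetime.datetime.now() - last_wct
    run_evaluation()
    last_wct = datetime.datetime.now() - duration

    print(f'batch_time={batch_time.total_seconds():.3f}s (evaluation time excluded)')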