diff --git a/composer/core/__init__.py b/composer/core/__init__.py
index 99e02ed861..748778b460 100644
--- a/composer/core/__init__.py
+++ b/composer/core/__init__.py
@@ -19,7 +19,7 @@
 from composer.core.serializable import Serializable
 from composer.core.state import State
 from composer.core.time import Time, Timestamp, TimeUnit, ensure_time
-from composer.core.types import JSON, Batch, BreakEpochException, Dataset, MemoryFormat, PyTorchScheduler, TrainerMode
+from composer.core.types import JSON, Batch, Dataset, MemoryFormat, PyTorchScheduler, TrainerMode
 
 __all__ = [
     'Algorithm',
@@ -46,6 +46,5 @@
     'JSON',
     'MemoryFormat',
     'TrainerMode',
-    'BreakEpochException',
     'validate_eval_automicrobatching',
 ]
diff --git a/composer/core/types.py b/composer/core/types.py
index e1a6e5b37c..2ffa343d4d 100644
--- a/composer/core/types.py
+++ b/composer/core/types.py
@@ -21,7 +21,7 @@
 from composer.utils import StringEnum
 
-__all__ = ['Batch', 'PyTorchScheduler', 'JSON', 'MemoryFormat', 'TrainerMode', 'BreakEpochException']
+__all__ = ['Batch', 'PyTorchScheduler', 'JSON', 'MemoryFormat', 'TrainerMode']
 
 Batch = Any
@@ -37,14 +37,6 @@
 JSON = Union[str, float, int, None, List['JSON'], Dict[str, 'JSON']]
 
 
-class BreakEpochException(Exception):
-    """Raising this exception will immediately end the current epoch.
-
-    If you're wondering whether you should use this, the answer is no.
-    """
-    pass
-
-
 class TrainerMode(StringEnum):
     """Enum to represent which mode the Trainer is in.
 
diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 8e17a1b87b..8d364eec6c 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -35,10 +35,9 @@
 from torchmetrics import Metric
 
 from composer.callbacks import CheckpointSaver, OptimizerMonitor
-from composer.core import (Algorithm, AlgorithmPass, Batch, BreakEpochException, Callback, DataSpec, Engine, Evaluator,
-                           Event, Precision, PyTorchScheduler, State, Time, Timestamp, TimeUnit, TrainerMode,
-                           ensure_data_spec, ensure_evaluator, ensure_time, get_precision_context,
-                           validate_eval_automicrobatching)
+from composer.core import (Algorithm, AlgorithmPass, Batch, Callback, DataSpec, Engine, Evaluator, Event, Precision,
+                           PyTorchScheduler, State, Time, Timestamp, TimeUnit, TrainerMode, ensure_data_spec,
+                           ensure_evaluator, ensure_time, get_precision_context, validate_eval_automicrobatching)
 from composer.devices import Device, DeviceCPU, DeviceGPU, DeviceMPS, DeviceTPU
 from composer.loggers import (ConsoleLogger, Logger, LoggerDestination, MosaicMLLogger, ProgressBarLogger,
                               RemoteUploaderDownloader, WandBLogger)
@@ -2019,143 +2018,140 @@ def _train_loop(self) -> None:
         last_wct = datetime.datetime.now()
 
         while self.state.timestamp < self.state.max_duration:
-            try:
-                if int(self.state.timestamp.batch_in_epoch) == 0:
-                    self.engine.run_event(Event.EPOCH_START)
-                    self.logger.log_metrics({'time/epoch': self.state.timestamp.epoch.value})
+            if int(self.state.timestamp.batch_in_epoch) == 0:
+                self.engine.run_event(Event.EPOCH_START)
+                self.logger.log_metrics({'time/epoch': self.state.timestamp.epoch.value})
 
-                dataloader = self.state.dataloader
-                if isinstance(dataloader, DataLoader) and isinstance(dataloader.sampler, DistributedSampler):
-                    dataloader.sampler.set_epoch(int(self.state.timestamp.epoch))
-
-                for batch_idx, self.state.batch in enumerate(self._iter_dataloader(TrainerMode.TRAIN)):
-                    # Spin dataloader forward unless dataloader handles internally with dataset_resumption
-                    if self.spin_dataloaders and 'train' not in self.state.dataset_resumption and batch_idx < int(
-                            self.state.timestamp.batch_in_epoch):
-                        # Restore the RNG state immediately before the next batch is yielded from the dataloader
-                        if batch_idx + 1 == int(self.state.timestamp.batch_in_epoch) and self._rng_state is not None:
-                            reproducibility.load_rng_state(self._rng_state)
-                            self._rng_state = None
-                        continue
-
-                    self.state.batch = self.state.device.batch_to_device(self.state.batch)
-                    self.state.batch = self._train_data_spec.device_transforms(self.state.batch)
-                    rank_num_samples = self._train_data_spec.get_num_samples_in_batch(self.state.batch)
-                    rank_num_tokens = self._train_data_spec.get_num_tokens_in_batch(self.state.batch)
-
-                    if self.state.deepspeed_enabled:
-                        self.state.batch = _fix_batch_precision_for_deepspeed(self.state.batch, self.state.precision)
-
-                    self.engine.run_event(Event.AFTER_DATALOADER)
-
-                    self.engine.run_event(Event.BATCH_START)
-
-                    # Log time values
-                    self.logger.log_metrics({
-                        'time/batch': self.state.timestamp.batch.value,
-                        'time/sample': self.state.timestamp.sample.value,
-                        'time/batch_in_epoch': self.state.timestamp.batch_in_epoch.value,
-                        'time/sample_in_epoch': self.state.timestamp.sample_in_epoch.value,
-                    })
-                    if rank_num_tokens > 0:
-                        self.logger.log_metrics({'time/token': self.state.timestamp.token.value})
-                        self.logger.log_metrics({'time/token_in_epoch': self.state.timestamp.token_in_epoch.value})
-
-                    total_loss_dict = self._train_batch(use_grad_scaling)
-
-                    if use_grad_scaling:
-                        self.state.scaler.update()
-
-                    # total_loss_dict can be None if gradient scaling failed
-                    if total_loss_dict is not None:
-                        map_collection(total_loss_dict, dist.all_reduce)
-                        total_loss_dict = {
-                            k: loss.cpu().item() / dist.get_world_size() for k, loss in total_loss_dict.items()
-                        }
-                        self.state.total_loss_dict = total_loss_dict
-                        self.logger.log_metrics(total_loss_dict)
-
-                    # The scheduler step.step() and compute_and_log_metrics() are going to be included in the
-                    # next batch's wall clock time. The time accumulation must be done here so schedulers
-                    # have the latest timing information
-
-                    now = datetime.datetime.now()
-
-                    batch_time = now - last_wct
-
-                    total_num_samples, total_num_tokens, batch_time = self._accumulate_time_across_ranks(
-                        rank_num_samples,
-                        rank_num_tokens,
-                        batch_time,
+            dataloader = self.state.dataloader
+            if isinstance(dataloader, DataLoader) and isinstance(dataloader.sampler, DistributedSampler):
+                dataloader.sampler.set_epoch(int(self.state.timestamp.epoch))
+
+            for batch_idx, self.state.batch in enumerate(self._iter_dataloader(TrainerMode.TRAIN)):
+                # Spin dataloader forward unless dataloader handles internally with dataset_resumption
+                if self.spin_dataloaders and 'train' not in self.state.dataset_resumption and batch_idx < int(
+                        self.state.timestamp.batch_in_epoch):
+                    # Restore the RNG state immediately before the next batch is yielded from the dataloader
+                    if batch_idx + 1 == int(self.state.timestamp.batch_in_epoch) and self._rng_state is not None:
+                        reproducibility.load_rng_state(self._rng_state)
+                        self._rng_state = None
+                    continue
+
+                self.state.batch = self.state.device.batch_to_device(self.state.batch)
+                self.state.batch = self._train_data_spec.device_transforms(self.state.batch)
+                rank_num_samples = self._train_data_spec.get_num_samples_in_batch(self.state.batch)
+                rank_num_tokens = self._train_data_spec.get_num_tokens_in_batch(self.state.batch)
+
+                if self.state.deepspeed_enabled:
+                    self.state.batch = _fix_batch_precision_for_deepspeed(self.state.batch, self.state.precision)
+
+                self.engine.run_event(Event.AFTER_DATALOADER)
+
+                self.engine.run_event(Event.BATCH_START)
+
+                # Log time values
+                self.logger.log_metrics({
+                    'time/batch': self.state.timestamp.batch.value,
+                    'time/sample': self.state.timestamp.sample.value,
+                    'time/batch_in_epoch': self.state.timestamp.batch_in_epoch.value,
+                    'time/sample_in_epoch': self.state.timestamp.sample_in_epoch.value,
+                })
+                if rank_num_tokens > 0:
+                    self.logger.log_metrics({'time/token': self.state.timestamp.token.value})
+                    self.logger.log_metrics({'time/token_in_epoch': self.state.timestamp.token_in_epoch.value})
+
+                total_loss_dict = self._train_batch(use_grad_scaling)
+
+                if use_grad_scaling:
+                    self.state.scaler.update()
+
+                # total_loss_dict can be None if gradient scaling failed
+                if total_loss_dict is not None:
+                    map_collection(total_loss_dict, dist.all_reduce)
+                    total_loss_dict = {
+                        k: loss.cpu().item() / dist.get_world_size() for k, loss in total_loss_dict.items()
+                    }
+                    self.state.total_loss_dict = total_loss_dict
+                    self.logger.log_metrics(total_loss_dict)
+
+                # The scheduler step.step() and compute_and_log_metrics() are going to be included in the
+                # next batch's wall clock time. The time accumulation must be done here so schedulers
+                # have the latest timing information
+
+                now = datetime.datetime.now()
+
+                batch_time = now - last_wct
+
+                total_num_samples, total_num_tokens, batch_time = self._accumulate_time_across_ranks(
+                    rank_num_samples,
+                    rank_num_tokens,
+                    batch_time,
+                )
+
+                # `now` is actually in the past, but want to include the time it takes to perform this reduction
+                last_wct = now
+
+                if self._scheduler_step_frequency == TimeUnit.BATCH:
+                    for scheduler in self.state.schedulers:
+                        scheduler.step()
+
+                if self.state.train_metrics is not None:
+                    self._compute_and_log_metrics(
+                        dataloader_label='train',
+                        metrics=self.state.train_metrics,
                     )
 
-                    # `now` is actually in the past, but want to include the time it takes to perform this reduction
-                    last_wct = now
+                self.state.previous_timestamp = self.state.timestamp
+                self.state.timestamp = self.state.timestamp.to_next_batch(
+                    samples=total_num_samples,
+                    tokens=total_num_tokens,
+                    duration=batch_time,
+                )
+
+                self.engine.run_event(Event.BATCH_END)
 
-                    if self._scheduler_step_frequency == TimeUnit.BATCH:
-                        for scheduler in self.state.schedulers:
-                            scheduler.step()
+                # Pause the timing during evaluation
+                # Evaluation time is tracked separately in state.eval_timestamp
+                duration = datetime.datetime.now() - last_wct
+                self._run_evaluators(Event.BATCH_END)
+                last_wct = datetime.datetime.now() - duration
 
-                    if self.state.train_metrics is not None:
-                        self._compute_and_log_metrics(
-                            dataloader_label='train',
-                            metrics=self.state.train_metrics,
-                        )
+                self.engine.run_event(Event.BATCH_CHECKPOINT)
+
+                if self.state.timestamp >= self.state.max_duration:
+                    # If max_duration is specified in batches, samples, or tokens, and
+                    # and the max_duration is reached mid-epoch, then break out of the dataloader
+                    # to finish the epoch early and finish training.
+                    finished_epoch_early = True
+                    break
 
-                    self.state.previous_timestamp = self.state.timestamp
-                    self.state.timestamp = self.state.timestamp.to_next_batch(
-                        samples=total_num_samples,
-                        tokens=total_num_tokens,
-                        duration=batch_time,
+            if not finished_epoch_early or self.state.dataloader_len == self.state.timestamp.batch_in_epoch:
+                # Trigger the epoch end events if the dataloader was exhausted.
+                # This happens if the "break" did not trigger above, or if it
+                # did (e.g. duration specified in samples/batches/tokens), but it is still
+                # the end of the dataloader (i.e. next(dataloader) would raise StopIteration)
+                if self.state.train_metrics is not None:
+                    self.state.train_metrics = self._ensure_metrics_device_and_dtype(self.state.train_metrics)
+                    self._compute_and_log_metrics(
+                        dataloader_label='train',
+                        metrics=self.state.train_metrics,
                    )
 
-                    self.engine.run_event(Event.BATCH_END)
-
-                    # Pause the timing during evaluation
-                    # Evaluation time is tracked separately in state.eval_timestamp
-                    duration = datetime.datetime.now() - last_wct
-                    self._run_evaluators(Event.BATCH_END)
-                    last_wct = datetime.datetime.now() - duration
-
-                    self.engine.run_event(Event.BATCH_CHECKPOINT)
-
-                    if self.state.timestamp >= self.state.max_duration:
-                        # If max_duration is specified in batches, samples, or tokens, and
-                        # and the max_duration is reached mid-epoch, then break out of the dataloader
-                        # to finish the epoch early and finish training.
-                        finished_epoch_early = True
-                        break
-
-                if not finished_epoch_early or self.state.dataloader_len == self.state.timestamp.batch_in_epoch:
-                    # Trigger the epoch end events if the dataloader was exhausted.
-                    # This happens if the "break" did not trigger above, or if it
-                    # did (e.g. duration specified in samples/batches/tokens), but it is still
-                    # the end of the dataloader (i.e. next(dataloader) would raise StopIteration)
-                    if self.state.train_metrics is not None:
-                        self.state.train_metrics = self._ensure_metrics_device_and_dtype(self.state.train_metrics)
-                        self._compute_and_log_metrics(
-                            dataloader_label='train',
-                            metrics=self.state.train_metrics,
-                        )
-
-                    if self._scheduler_step_frequency == TimeUnit.EPOCH:
-                        for scheduler in self.state.schedulers:
-                            scheduler.step()
-
-                    self.state.previous_timestamp = self.state.timestamp
-                    self.state.timestamp = self.state.timestamp.to_next_epoch()
-
-                    self.engine.run_event(Event.EPOCH_END)
-
-                    # Pause the timing during evaluation
-                    # Evaluation time is tracked separately in state.eval_timestamp
-                    duration = datetime.datetime.now() - last_wct
-                    self._run_evaluators(Event.EPOCH_END)
-                    last_wct = datetime.datetime.now() - duration
-
-                    self.engine.run_event(Event.EPOCH_CHECKPOINT)
-            except BreakEpochException:
-                log.info(f'Skipping the rest of Epoch {int(self.state.timestamp.epoch)}')
+                if self._scheduler_step_frequency == TimeUnit.EPOCH:
+                    for scheduler in self.state.schedulers:
+                        scheduler.step()
+
+                self.state.previous_timestamp = self.state.timestamp
+                self.state.timestamp = self.state.timestamp.to_next_epoch()
+
+                self.engine.run_event(Event.EPOCH_END)
+
+                # Pause the timing during evaluation
+                # Evaluation time is tracked separately in state.eval_timestamp
+                duration = datetime.datetime.now() - last_wct
+                self._run_evaluators(Event.EPOCH_END)
+                last_wct = datetime.datetime.now() - duration
+
+                self.engine.run_event(Event.EPOCH_CHECKPOINT)
 
         # Log final time values
         self.logger.log_metrics({