Skip to content

Commit

Permalink
Train loss NaN checking callback (mosaicml#2704)
Browse files Browse the repository at this point in the history
  • Loading branch information
coryMosaicML authored Nov 9, 2023
1 parent 6f29ad6 commit eb7a9cf
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 0 deletions.
2 changes: 2 additions & 0 deletions composer/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from composer.callbacks.lr_monitor import LRMonitor
from composer.callbacks.memory_monitor import MemoryMonitor
from composer.callbacks.mlperf import MLPerfCallback
from composer.callbacks.nan_monitor import NaNMonitor
from composer.callbacks.optimizer_monitor import OptimizerMonitor
from composer.callbacks.runtime_estimator import RuntimeEstimator
from composer.callbacks.speed_monitor import SpeedMonitor
Expand All @@ -28,6 +29,7 @@
'OptimizerMonitor',
'LRMonitor',
'MemoryMonitor',
'NaNMonitor',
'SpeedMonitor',
'CheckpointSaver',
'MLPerfCallback',
Expand Down
28 changes: 28 additions & 0 deletions composer/callbacks/nan_monitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Callback for catching loss NaNs."""

from typing import Sequence

import torch

from composer import Callback, Logger, State

__all__ = ['NaNMonitor']


class NaNMonitor(Callback):
    """Catches NaNs in the loss and raises an error if one is found.

    ``state.loss`` may be a single tensor, a sequence of tensors, or a dict
    mapping loss names to tensors. Any other type is rejected with a
    ``TypeError``.
    """

    def after_loss(self, state: State, logger: Logger):
        """Check if loss is NaN and raise an error if so.

        Args:
            state (State): Only ``state.loss`` is read.
            logger (Logger): Unused by this callback.

        Raises:
            RuntimeError: If any inspected loss tensor contains a NaN.
            TypeError: If ``state.loss`` is not a tensor, a sequence of tensors,
                or a dict of tensors.
        """
        loss = state.loss
        if isinstance(loss, torch.Tensor):
            if torch.isnan(loss).any():
                raise RuntimeError('Train loss contains a NaN.')
        elif isinstance(loss, Sequence):
            for component in loss:
                if torch.isnan(component).any():
                    raise RuntimeError('Train loss contains a NaN.')
        elif isinstance(loss, dict):
            # Named losses: check every entry and report which one went bad so the
            # failure is actionable. The message keeps the same prefix as the other
            # branches so substring matching on it still works.
            for name, component in loss.items():
                if torch.isnan(component).any():
                    raise RuntimeError(f'Train loss contains a NaN. Offending loss key: {name!r}.')
        else:
            raise TypeError(f'Loss is of type {type(loss)}, but should be a tensor, '
                            'a sequence of tensors, or a dict of tensors')
1 change: 1 addition & 0 deletions docs/source/trainer/callbacks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ components of training.
~lr_monitor.LRMonitor
~optimizer_monitor.OptimizerMonitor
~memory_monitor.MemoryMonitor
~nan_monitor.NaNMonitor
~image_visualizer.ImageVisualizer
~mlperf.MLPerfCallback
~threshold_stopper.ThresholdStopper
Expand Down
34 changes: 34 additions & 0 deletions tests/callbacks/test_nan_monitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

import pytest
from torch.utils.data import DataLoader

from composer.callbacks import NaNMonitor
from composer.optim import DecoupledAdamW
from composer.trainer import Trainer
from tests.common import RandomClassificationDataset, SimpleModel


@pytest.mark.parametrize('should_nan', [True, False])
def test_nan_monitor(should_nan):
    """Fit a small model with ``NaNMonitor`` attached and check when it raises.

    With ``should_nan=True`` the optimizer gets an enormous learning rate, which is
    intended to drive the loss to NaN, and the run must fail with the monitor's
    ``RuntimeError``. With ``should_nan=False`` the learning rate is tiny and the
    run must finish without raising.
    """
    monitor = NaNMonitor()
    net = SimpleModel()

    # Huge LR to force a NaN; minuscule LR when training should stay finite.
    if should_nan:
        learning_rate = 1e20
    else:
        learning_rate = 1e-10

    trainer = Trainer(
        model=net,
        callbacks=monitor,
        train_dataloader=DataLoader(RandomClassificationDataset()),
        optimizers=DecoupledAdamW(net.parameters(), lr=learning_rate),
        max_duration='100ba',
    )

    # Healthy case: training simply has to complete.
    if not should_nan:
        trainer.fit()
        return

    # NaN case: the callback must abort training with its specific message.
    with pytest.raises(RuntimeError) as excinfo:
        trainer.fit()
    assert 'Train loss contains a NaN.' in str(excinfo.value)

0 comments on commit eb7a9cf

Please sign in to comment.