forked from mosaicml/composer
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Train loss NaN checking callback (mosaicml#2704)
- Loading branch information
1 parent
6f29ad6
commit eb7a9cf
Showing
4 changed files
with
65 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# Copyright 2022 MosaicML Composer authors | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""Callback for catching loss NaNs.""" | ||
|
||
from typing import Sequence | ||
|
||
import torch | ||
|
||
from composer import Callback, Logger, State | ||
|
||
__all__ = ['NaNMonitor'] | ||
|
||
|
||
class NaNMonitor(Callback): | ||
"""Catches NaNs in the loss and raises an error if one is found.""" | ||
|
||
def after_loss(self, state: State, logger: Logger): | ||
"""Check if loss is NaN and raise an error if so.""" | ||
if isinstance(state.loss, torch.Tensor): | ||
if torch.isnan(state.loss).any(): | ||
raise RuntimeError('Train loss contains a NaN.') | ||
elif isinstance(state.loss, Sequence): | ||
for loss in state.loss: | ||
if torch.isnan(loss).any(): | ||
raise RuntimeError('Train loss contains a NaN.') | ||
else: | ||
raise TypeError(f'Loss is of type {type(state.loss)}, but should be a tensor or a sequence of tensors') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# Copyright 2022 MosaicML Composer authors | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import pytest | ||
from torch.utils.data import DataLoader | ||
|
||
from composer.callbacks import NaNMonitor | ||
from composer.optim import DecoupledAdamW | ||
from composer.trainer import Trainer | ||
from tests.common import RandomClassificationDataset, SimpleModel | ||
|
||
|
||
@pytest.mark.parametrize('should_nan', [True, False]) | ||
def test_nan_monitor(should_nan): | ||
# Make the callback | ||
nan_monitor = NaNMonitor() | ||
# Test model | ||
model = SimpleModel() | ||
# Construct the trainer and train. Make the LR huge to force a NaN, small if it shouldn't | ||
lr = 1e20 if should_nan else 1e-10 | ||
trainer = Trainer( | ||
model=model, | ||
callbacks=nan_monitor, | ||
train_dataloader=DataLoader(RandomClassificationDataset()), | ||
optimizers=DecoupledAdamW(model.parameters(), lr=lr), | ||
max_duration='100ba', | ||
) | ||
# If it should NaN, expect a RuntimeError | ||
if should_nan: | ||
with pytest.raises(RuntimeError) as excinfo: | ||
trainer.fit() | ||
assert 'Train loss contains a NaN.' in str(excinfo.value) | ||
else: | ||
trainer.fit() |