
[Trainer] Add nan/inf logging filter #13619

Merged
11 changes: 9 additions & 2 deletions src/transformers/trainer.py
@@ -1297,9 +1297,16 @@ def train(
                ):
                    # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
                    with model.no_sync():
-                       tr_loss += self.training_step(model, inputs)
+                       tr_loss_step = self.training_step(model, inputs)
                else:
-                   tr_loss += self.training_step(model, inputs)
+                   tr_loss_step = self.training_step(model, inputs)

+               if args.logging_nan_inf_filter and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)):
+                   # if loss is nan or inf, simply add the average of the losses logged since the last logging step
+                   tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
+               else:
+                   tr_loss += tr_loss_step

                self.current_flos += float(self.floating_point_ops(inputs))

                # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps
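For intuition, here is a minimal standalone sketch of the filtering arithmetic. The helper and its global_step/last_logged parameters are illustrative stand-ins for self.state.global_step and self._globalstep_last_logged, not Trainer internals:

import math

def accumulate_loss(tr_loss, tr_loss_step, global_step, last_logged, filter_nan_inf=True):
    # Sketch of the new behavior: a nan/inf step loss is not added to the running
    # sum; instead the sum grows by the mean of what has accumulated since the
    # last logging event, so the logged average stays finite.
    if filter_nan_inf and (math.isnan(tr_loss_step) or math.isinf(tr_loss_step)):
        return tr_loss + tr_loss / (1 + global_step - last_logged)
    return tr_loss + tr_loss_step

tr_loss, last_logged = 0.0, 0
for step, loss in enumerate([2.0, 1.5, 1.0, float("nan")]):
    tr_loss = accumulate_loss(tr_loss, loss, global_step=step, last_logged=last_logged)
print(tr_loss)  # 5.625 -- the nan step contributed 4.5 / 4 = 1.125 instead of poisoning the sum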
9 changes: 9 additions & 0 deletions src/transformers/training_args.py
@@ -468,6 +468,15 @@ class TrainingArguments:
    )
    logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"})
    logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
    logging_nan_inf_filter: bool = field(
        default=False,
        metadata={
            "help": (
                "Filter nan and inf losses for logging. "
                "Note this flag only affects the logging output and not the optimization step."
            )
        },
    )

Collaborator (review comment on the help string): No need to be this long here, but there should be a proper documentation above that can be longer!
    save_strategy: IntervalStrategy = field(
        default="steps",
        metadata={"help": "The checkpoint save strategy to use."},
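For reference, a hedged usage sketch of the new argument. The model and dataset are placeholders assumed to be defined elsewhere; the flag only changes what gets logged, not the optimizer updates:

from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="./out",
    logging_steps=50,
    logging_nan_inf_filter=True,  # filter nan/inf step losses out of the logged running average
)
# `my_model` and `my_train_dataset` are hypothetical objects created beforehand.
trainer = Trainer(model=my_model, args=args, train_dataset=my_train_dataset)
trainer.train()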
26 changes: 26 additions & 0 deletions tests/test_trainer.py
@@ -15,6 +15,7 @@

import dataclasses
import gc
import math
import os
import random
import re
@@ -528,6 +529,31 @@ def test_number_of_steps_in_training(self):
        train_output = trainer.train()
        self.assertEqual(train_output.global_step, 10)

    def test_logging_inf_nan_filter(self):
        config = GPT2Config(vocab_size=100, n_positions=128, n_ctx=128, n_embd=32, n_layer=3, n_head=4)
        tiny_gpt2 = GPT2LMHeadModel(config)
        x = torch.randint(0, 100, (128,))
        train_dataset = RepeatDataset(x)

        # Trainer without inf/nan filter
        args = TrainingArguments("./test", learning_rate=1e9, logging_steps=5)
        trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset)
        trainer.train()
        log_history_no_filter = trainer.state.log_history

        # Trainer with inf/nan filter
        args = TrainingArguments("./test", learning_rate=1e9, logging_steps=5, logging_nan_inf_filter=True)
        trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset)
        trainer.train()
        log_history_filter = trainer.state.log_history

        def is_any_loss_nan_or_inf(log_history):
            losses = [l["loss"] for l in log_history[:-1]]
            return any(math.isnan(x) for x in losses) or any(math.isinf(x) for x in losses)

        self.assertTrue(is_any_loss_nan_or_inf(log_history_no_filter))
        self.assertFalse(is_any_loss_nan_or_inf(log_history_filter))

    def test_train_and_eval_dataloaders(self):
        n_gpu = max(1, torch.cuda.device_count())
        trainer = get_regression_trainer(learning_rate=0.1, per_device_train_batch_size=16)