
Commit f6fe41c

Reset loss to zero on logging in Trainer to avoid bfloat16 issues (#8561)
* make tr_loss regular float
* Revert "make tr_loss regular float". This reverts commit c9d7ccf.
* reset loss at each logging step
* keep track of total loss with _total_loss_scalar
* add remaining tr_loss at the end
1 parent: b592728
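As an aside on the title: here is a hypothetical illustration (not part of the commit) of the kind of precision problem that makes the old "running total minus last logged total" bookkeeping fragile in low-precision dtypes. bfloat16 carries only about 8 significant bits, so once a running sum is large, a small per-step increment is rounded away:

import torch

total = torch.tensor(1000.0, dtype=torch.bfloat16)  # large running loss sum
step = torch.tensor(0.3, dtype=torch.bfloat16)      # one new step's loss
print(total + step - total)  # tensor(0., dtype=torch.bfloat16): the increment is rounded away
print(step)                  # tensor(0.3008, dtype=torch.bfloat16): fine on its own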


1 file changed: +11 −6 lines changed

src/transformers/trainer.py

Lines changed: 11 additions & 6 deletions
@@ -696,8 +696,10 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
         self.state.is_local_process_zero = self.is_local_process_zero()
         self.state.is_world_process_zero = self.is_world_process_zero()

+        # tr_loss is a tensor to avoid synchronization of TPUs through .item()
         tr_loss = torch.tensor(0.0).to(self.args.device)
-        self._logging_loss_scalar = 0
+        # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
+        self._total_loss_scalar = 0.0
         self._globalstep_last_logged = 0
         self._total_flos = self.state.total_flos
         model.zero_grad()
@@ -812,23 +814,26 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
             self.log({"total_flos": self.state.total_flos})

         self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
+        # add remaining tr_loss
+        self._total_loss_scalar += tr_loss.item()

-        return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step)
+        return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step)

     def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
         if self.control.should_log:
             logs: Dict[str, float] = {}
             tr_loss_scalar = tr_loss.item()
-            logs["loss"] = (tr_loss_scalar - self._logging_loss_scalar) / (
-                self.state.global_step - self._globalstep_last_logged
-            )
+            # reset tr_loss to zero
+            tr_loss -= tr_loss
+
+            logs["loss"] = tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged)
             # backward compatibility for pytorch schedulers
             logs["learning_rate"] = (
                 self.lr_scheduler.get_last_lr()[0]
                 if version.parse(torch.__version__) >= version.parse("1.4")
                 else self.lr_scheduler.get_lr()[0]
             )
-            self._logging_loss_scalar = tr_loss_scalar
+            self._total_loss_scalar += tr_loss_scalar
             self._globalstep_last_logged = self.state.global_step

             self.log(logs)
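Taken together, the diff swaps the old "subtract the last logged scalar from a run-long running total" scheme for "reset the running tensor at every logging step and accumulate a plain Python float". A minimal standalone sketch of that pattern, using hypothetical names rather than the Trainer API:

import torch

def train_loop(step_losses, logging_steps=2):
    tr_loss = torch.tensor(0.0)   # kept as a tensor so .item() (a sync point) stays rare
    total_loss_scalar = 0.0       # plain Python float accumulator, like _total_loss_scalar
    last_logged_step = 0

    for step, loss in enumerate(step_losses, start=1):
        tr_loss += loss
        if step % logging_steps == 0:
            tr_loss_scalar = tr_loss.item()
            tr_loss -= tr_loss    # reset in place instead of subtracting a logged scalar later
            print({"loss": tr_loss_scalar / (step - last_logged_step)})
            total_loss_scalar += tr_loss_scalar
            last_logged_step = step

    total_loss_scalar += tr_loss.item()   # "add remaining tr_loss" at the end
    return total_loss_scalar / len(step_losses)

avg = train_loop([torch.tensor(x) for x in (0.9, 0.7, 0.5, 0.3, 0.2)])
print(avg)  # average loss over all 5 steps, about 0.52

The in-place tr_loss -= tr_loss (rather than rebinding the name to a fresh zero tensor) matters in the real code because _maybe_log_save_evaluate receives the tensor as an argument: an in-place update is visible to the accumulation loop in train, while a rebind would not be.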
