diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index c96ce058df..eccc9496c2 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -1401,7 +1401,7 @@ def cross_entropy_loss(logits, labels): logits = logits[..., :-1, :].contiguous() labels = labels[..., 1:].contiguous() # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() + loss_fct = nn.CrossEntropyLoss(ignore_index=self.label_pad_token_id) logits = logits.view(-1, logits.shape[-1]) labels = labels.view(-1) # Enable model parallelism