Add ignore_index in DPOTrainer's nn.CrossEntropyLoss (#1987)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
akakakakakaa and kashif authored Aug 28, 2024
1 parent 47ab034 commit 10f70fa
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion trl/trainer/dpo_trainer.py
@@ -1401,7 +1401,7 @@ def cross_entropy_loss(logits, labels):
  logits = logits[..., :-1, :].contiguous()
  labels = labels[..., 1:].contiguous()
  # Flatten the tokens
- loss_fct = nn.CrossEntropyLoss()
+ loss_fct = nn.CrossEntropyLoss(ignore_index=self.label_pad_token_id)
  logits = logits.view(-1, logits.shape[-1])
  labels = labels.view(-1)
  # Enable model parallelism
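
Note on the change: with the default mean reduction, `nn.CrossEntropyLoss` averages only over targets that are not equal to `ignore_index`, so passing the pad label keeps padded positions out of the loss. The sketch below is a plain-Python emulation of that behavior, not TRL or PyTorch code. `mean_cross_entropy`, `PAD_ID`, and the toy logits are hypothetical names and values chosen for illustration.

```python
import math


def mean_cross_entropy(logits, labels, ignore_index=None):
    """Mean token-level cross-entropy over flattened tokens.

    Positions whose label equals `ignore_index` are skipped, which mimics
    (under the assumption of mean reduction) what ignore_index does.
    """
    total, count = 0.0, 0
    for row, label in zip(logits, labels):
        if label == ignore_index:
            continue  # padded position: contributes nothing to the loss
        # Numerically stable log-sum-exp of the logits for this token.
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_z - row[label]
        count += 1
    # No valid targets at all: report NaN rather than dividing by zero.
    return total / count if count else float("nan")


PAD_ID = 2  # hypothetical pad label; it collides with a real class index

# Four flattened tokens over a 3-class vocabulary; the last two are padding.
logits = [
    [2.0, 0.5, 0.1],
    [0.2, 3.0, 0.3],
    [0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0],
]
labels = [0, 1, PAD_ID, PAD_ID]

unmasked = mean_cross_entropy(logits, labels)  # pads scored as real targets
masked = mean_cross_entropy(logits, labels, ignore_index=PAD_ID)

print(f"without ignore_index: {unmasked:.4f}")
print(f"with ignore_index:    {masked:.4f}")
```

In this toy case the padded positions have uniform logits, so each one adds `log(3)` to the unmasked sum and inflates the average. The masked value depends only on the two real tokens.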