

Wrong tensor index for roll and truncate in DPOTrainer's concatenated_forward() #2330

Closed
yanghh2000 opened this issue Nov 6, 2024 · 1 comment · Fixed by #2332
Labels
🐛 bug (Something isn't working) · 🏋 DPO (Related to DPO)

Comments

@yanghh2000
Contributor

System Info

Not system-specific; it is a tensor index error in the source code.

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

The following snippet from DPOTrainer.concatenated_forward() truncates one column too many:

# Get the first column idx that is all zeros and remove every column after that
empty_cols = torch.sum(attention_mask, dim=0) == 0
first_empty_col = torch.nonzero(empty_cols)[0].item() if empty_cols.any() else attention_mask.size(1) + 1
input_ids = input_ids[:, : first_empty_col - 1]
attention_mask = attention_mask[:, : first_empty_col - 1]
loss_mask = loss_mask[:, : first_empty_col - 1]
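
To see the off-by-one concretely, here is a minimal standalone sketch (toy tensors of my own, not TRL code; the variable names mirror the snippet above):

import torch

# Batch of 2 sequences padded to length 5; columns 3 and 4 are all padding.
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 0, 0, 0]])

empty_cols = torch.sum(attention_mask, dim=0) == 0     # [False, False, False, True, True]
first_empty_col = torch.nonzero(empty_cols)[0].item()  # 3 (already a 0-based index)

# The buggy slice keeps only columns 0-1 and drops column 2,
# which still holds a real token in the first sequence.
print(attention_mask[:, : first_empty_col - 1])
# tensor([[1, 1],
#         [1, 1]])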

Expected behavior

torch.nonzero returns the 0-based indices of the non-zero elements, so there is no need to subtract 1 from first_empty_col: slicing with [:, :first_empty_col] already keeps exactly the columns before the first all-zero one. Likewise, the fallback when no column is empty should be attention_mask.size(1) rather than attention_mask.size(1) + 1, so the slice keeps every column.
The correct code should be:

empty_cols = torch.sum(attention_mask, dim=0) == 0
first_empty_col = torch.nonzero(empty_cols)[0].item() if empty_cols.any() else attention_mask.size(1)
input_ids = input_ids[:, : first_empty_col]
attention_mask = attention_mask[:, : first_empty_col]
loss_mask = loss_mask[:, : first_empty_col]
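
Run on the same toy tensors as above, the corrected slice keeps every non-empty column and is a no-op when nothing is padded (again a sketch, not the TRL source):

import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 0, 0, 0]])

empty_cols = torch.sum(attention_mask, dim=0) == 0
first_empty_col = torch.nonzero(empty_cols)[0].item() if empty_cols.any() else attention_mask.size(1)

truncated = attention_mask[:, :first_empty_col]  # columns 0-2 survive
assert truncated.shape[1] == 3 and truncated[:, -1].any()

# If no column is all zeros, first_empty_col == attention_mask.size(1),
# so the slice keeps the full width.
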
@qgallouedec
Member

Good catch! Thanks! Do you mind opening a PR to fix that?
