tests/test_dpo_trainer.py (1 addition, 0 deletions)

@@ -1422,6 +1422,7 @@ class TestDPOVisionTrainer(TrlTestCase):
# ("trl-internal-testing/tiny-PaliGemmaForConditionalGeneration",),
("trl-internal-testing/tiny-LlavaForConditionalGeneration",),
("trl-internal-testing/tiny-LlavaNextForConditionalGeneration",),
("trl-internal-testing/tiny-Gemma3ForConditionalGeneration",),
]
)
def test_vdpo_trainer(self, model_id):
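
What the new test case exercises, roughly: DPOTrainer fine-tuning the tiny Gemma3 vision-language checkpoint on an image preference dataset, where the Gemma3 processor emits token_type_ids that the dpo_trainer.py changes below have to carry through. The sketch below is an assumed setup rather than the test body itself; in particular, AutoModelForImageTextToText and the inline one-example dataset are stand-ins.

from datasets import Dataset, features
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor
from trl import DPOConfig, DPOTrainer

model_id = "trl-internal-testing/tiny-Gemma3ForConditionalGeneration"
model = AutoModelForImageTextToText.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

# One-example conversational image preference dataset (stand-in for the test fixture).
dataset = Dataset.from_dict(
    {
        "prompt": [[{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe."}]}]],
        "chosen": [[{"role": "assistant", "content": [{"type": "text", "text": "A solid red square."}]}]],
        "rejected": [[{"role": "assistant", "content": [{"type": "text", "text": "A cat."}]}]],
        "images": [[Image.new("RGB", (64, 64), color="red")]],
    }
)
dataset = dataset.cast_column("images", features.Sequence(features.Image()))

trainer = DPOTrainer(
    model=model,
    args=DPOConfig(output_dir="gemma3-vdpo-smoke", max_steps=1, report_to="none"),
    processing_class=processor,
    train_dataset=dataset,
)
trainer.train()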

trl/trainer/dpo_trainer.py (39 additions, 4 deletions)

@@ -177,6 +177,9 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:
if "ref_chosen_logps" in examples[0] and "ref_rejected_logps" in examples[0]:
output["ref_chosen_logps"] = ref_chosen_logps
output["ref_rejected_logps"] = ref_rejected_logps
if "token_type_ids" in examples[0]:
token_type_ids = [torch.tensor(example["token_type_ids"]) for example in examples]
output["token_type_ids"] = pad(token_type_ids, padding_value=0, padding_side="left")

return output
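
The collator addition above left-pads each example's token_type_ids to the batch width with 0, matching how the prompt tensors are padded. A toy illustration of that call (not from the PR), assuming the usual convention of 1 for image tokens and 0 for text tokens:

import torch
from trl.trainer.utils import pad

examples = [
    {"token_type_ids": [1, 1, 0, 0]},  # two image tokens, then text
    {"token_type_ids": [1, 0]},        # shorter prompt
]
token_type_ids = [torch.tensor(example["token_type_ids"]) for example in examples]
padded = pad(token_type_ids, padding_value=0, padding_side="left")
# padded -> tensor([[1, 1, 0, 0],
#                   [0, 0, 1, 0]])  # padding added on the left, content kept on the right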

@@ -790,6 +793,8 @@ def process_row(
output["pixel_attention_mask"] = processed_features["pixel_attention_mask"][0]
if "image_sizes" in processed_features:
output["image_sizes"] = processed_features["image_sizes"][0]
if "token_type_ids" in processed_features:
output["token_type_ids"] = processed_features["token_type_ids"][0]

return output

@@ -804,6 +809,7 @@ def _set_signature_columns_if_needed(self):
"chosen_input_ids",
"rejected_input_ids",
"image_sizes",
"token_type_ids",
"ref_chosen_logps",
"ref_rejected_logps",
]

@@ -991,6 +997,8 @@ def concatenated_inputs(
             )
         if "image_sizes" in batch:
             output["image_sizes"] = torch.cat([batch["image_sizes"], batch["image_sizes"]], dim=0)
+        if "token_type_ids" in batch:
+            output["token_type_ids"] = torch.cat((batch["token_type_ids"], batch["token_type_ids"]))

         # Concatenate the chosen and rejected completions
         max_completion_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
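
concatenated_inputs stacks the chosen and rejected halves of the batch, so prompt-level tensors are duplicated along the batch dimension; the addition above does the same for token_type_ids. A toy illustration (not from the PR):

import torch

token_type_ids = torch.tensor([[0, 1, 1, 0]])  # one prompt; 1 marks image positions (assumed convention)
doubled = torch.cat((token_type_ids, token_type_ids))
# doubled.shape -> torch.Size([2, 4]); row 0 lines up with "chosen", row 1 with "rejected"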

@@ -1516,6 +1524,9 @@ def concatenated_forward(
         # Concatenate the prompt and completion inputs
         input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1)
         attention_mask = torch.cat((prompt_attention_mask, completion_attention_mask), dim=1)
+        if "token_type_ids" in concatenated_batch:
+            prompt_token_type_ids = concatenated_batch["token_type_ids"]
+            token_type_ids = pad_to_length(prompt_token_type_ids, input_ids.shape[1], 0)
         # Mask the prompt but not the completion for the loss
         loss_mask = torch.cat(
             (torch.zeros_like(prompt_attention_mask), completion_attention_mask),
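
The token_type_ids produced by the processor only cover the prompt, so the addition above right-pads them with 0 up to the full prompt-plus-completion width; completion tokens are ordinary text, so 0 is the correct fill under the assumed convention. A toy illustration of pad_to_length (not from the PR):

import torch
from trl.trainer.utils import pad_to_length

prompt_token_type_ids = torch.tensor([[1, 1, 0]])  # two image tokens, one text token
full_length = 5  # prompt + completion width
token_type_ids = pad_to_length(prompt_token_type_ids, full_length, 0)
# token_type_ids -> tensor([[1, 1, 0, 0, 0]])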

@@ -1528,19 +1539,35 @@ def concatenated_forward(
                 # Flush left to reduce the memory usage
                 # [[0, 0, x, x, x, x], -> [[x, x, x, x],
                 #  [0, x, x, x, 0, 0]]     [x, x, x, 0]]
-                attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
+                if "token_type_ids" in concatenated_batch:
+                    attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
+                        attention_mask, input_ids, loss_mask, token_type_ids
+                    )
+                else:
+                    attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
                 attention_mask = attention_mask[:, : self.max_length]
                 input_ids = input_ids[:, : self.max_length]
                 loss_mask = loss_mask[:, : self.max_length]
             elif self.truncation_mode == "keep_end":
                 # Flush right before truncating left, then flush left
                 # [[0, 0, x, x, x, x], -> [[0, 0, x, x],
                 #  [0, x, x, x, 0, 0]]     [0, x, x, x]]
-                attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask)
+                if "token_type_ids" in concatenated_batch:
+                    attention_mask, input_ids, loss_mask, token_type_ids = flush_right(
+                        attention_mask, input_ids, loss_mask, token_type_ids
+                    )
+                    token_type_ids = token_type_ids[:, -self.max_length :]
+                else:
+                    attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask)
                 input_ids = input_ids[:, -self.max_length :]
                 attention_mask = attention_mask[:, -self.max_length :]
                 loss_mask = loss_mask[:, -self.max_length :]
-                attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
+                if "token_type_ids" in concatenated_batch:
+                    attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
+                        attention_mask, input_ids, loss_mask, token_type_ids
+                    )
+                else:
+                    attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
             else:
                 raise ValueError(
                     f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', "

@@ -1550,7 +1577,15 @@ def concatenated_forward(
             # Flush left to reduce the memory usage
             # [[0, 0, x, x, x, x], -> [[x, x, x, x],
             #  [0, x, x, x, 0, 0]]     [x, x, x, 0]]
-            attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
+            if "token_type_ids" in concatenated_batch:
+                attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
+                    attention_mask, input_ids, loss_mask, token_type_ids
+                )
+            else:
+                attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)

+        if "token_type_ids" in concatenated_batch:
+            model_kwargs["token_type_ids"] = token_type_ids
+
         if self.use_logits_to_keep:
             # Compute logits_to_keep based on loss_mask pattern:
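
Further down in concatenated_forward (outside this diff), model_kwargs is unpacked into the forward call, so the net effect of the last addition is roughly the following, sketched in simplified form rather than quoted from the file:

# Simplified sketch; the real call site also handles padding-free batching and logits_to_keep.
outputs = model(input_ids, attention_mask=attention_mask, **model_kwargs)
# For Gemma3-style vision-language models, model_kwargs now carries token_type_ids,
# which the model uses to distinguish image positions from text positions.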