🌋 Add support for LLaVA-Next in DPOTrainer (huggingface#2413)
* add support for llava-next in dpotrainer

* enable unit test

* code style

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Ignore last layer in test

---------

Co-authored-by: zesong.cwz <zesong.cwz@taobao.com>
Co-authored-by: 1rubbishyuan <2773496952@qq.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
5 people authored Nov 29, 2024
1 parent 94e4135 commit 8d9cfaa
Showing 2 changed files with 14 additions and 3 deletions.
7 changes: 5 additions & 2 deletions tests/test_dpo_trainer.py
@@ -1139,7 +1139,7 @@ class DPOVisionTrainerTester(unittest.TestCase):
             ("trl-internal-testing/tiny-Idefics2ForConditionalGeneration",),
             # ("trl-internal-testing/tiny-PaliGemmaForConditionalGeneration",),
             ("trl-internal-testing/tiny-LlavaForConditionalGeneration",),
-            # ("trl-internal-testing/tiny-LlavaNextForConditionalGeneration",),
+            ("trl-internal-testing/tiny-LlavaNextForConditionalGeneration",),
         ]
     )
     def test_vdpo_trainer(self, model_id):
def test_vdpo_trainer(self, model_id):
@@ -1211,7 +1211,10 @@ def test_vdpo_trainer(self, model_id):
         for n, param in previous_trainable_params.items():
             new_param = trainer.model.get_parameter(n)
             if param.sum() != 0:  # ignore 0 biases
-                if model_id == "trl-internal-testing/tiny-LlavaForConditionalGeneration" and (
+                if model_id in [
+                    "trl-internal-testing/tiny-LlavaForConditionalGeneration",
+                    "trl-internal-testing/tiny-LlavaNextForConditionalGeneration",
+                ] and (
                     n.startswith("vision_tower.vision_model.encoder.layers.1")
                     or n == "vision_tower.vision_model.post_layernorm.weight"
                 ):
10 changes: 9 additions & 1 deletion trl/trainer/dpo_trainer.py
@@ -143,6 +143,8 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:
             output["pixel_values"] = pad(pixel_values, padding_value=0.0)
         if "pixel_attention_mask" in examples[0]:
             output["pixel_attention_mask"] = pad(pixel_attention_mask, padding_value=0)
+        if "image_sizes" in examples[0]:
+            output["image_sizes"] = torch.tensor([example["image_sizes"] for example in examples])
 
         return output

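For context: LLaVA-Next's processor emits an image_sizes entry alongside pixel_values, recording each image's original (height, width) so the model can unpad and reassemble its multi-resolution patches, and the collator branch above stacks those entries into a single batch tensor instead of padding them. A minimal sketch of that stacking with illustrative values (the exact per-example layout of image_sizes is an assumption here, not taken from a real processor output):

import torch

# Two hypothetical collator inputs; LLaVA-Next-style processors typically attach
# one [height, width] pair per image under "image_sizes" (values are made up).
examples = [
    {"prompt_input_ids": [1, 2, 3], "image_sizes": [[336, 672]]},
    {"prompt_input_ids": [4, 5], "image_sizes": [[672, 336]]},
]

# Same stacking as the new collator branch: no padding, just one tensor per batch.
image_sizes = torch.tensor([example["image_sizes"] for example in examples])
print(image_sizes.shape)  # torch.Size([2, 1, 2]) -> (batch, images per example, 2)
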
@@ -645,6 +647,8 @@ def process_row(features, processing_class, max_prompt_length, max_completion_length
 
         if "pixel_attention_mask" in processed_features:
             output["pixel_attention_mask"] = processed_features["pixel_attention_mask"][0]
+        if "image_sizes" in processed_features:
+            output["image_sizes"] = processed_features["image_sizes"][0]
 
         return output

@@ -685,7 +689,7 @@ def _set_signature_columns_if_needed(self):
         # In DPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
         # Instead, we set them to the columns expected by `DPODataCollatorWithPadding`, hence the override.
         if self._signature_columns is None:
-            self._signature_columns = ["prompt_input_ids", "chosen_input_ids", "rejected_input_ids"]
+            self._signature_columns = ["prompt_input_ids", "chosen_input_ids", "rejected_input_ids", "image_sizes"]
 
     def get_train_dataloader(self) -> DataLoader:
         """
@@ -855,6 +859,8 @@ def concatenated_inputs(
             output["pixel_attention_mask"] = torch.cat(
                 [batch["pixel_attention_mask"], batch["pixel_attention_mask"]], dim=0
             )
+        if "image_sizes" in batch:
+            output["image_sizes"] = torch.cat([batch["image_sizes"], batch["image_sizes"]], dim=0)
 
         # Concatenate the chosen and rejected completions
         max_completion_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
@@ -1078,6 +1084,8 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]):
             model_kwargs["pixel_values"] = concatenated_batch["pixel_values"]
         if "pixel_attention_mask" in concatenated_batch:
             model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"]
+        if "image_sizes" in concatenated_batch:
+            model_kwargs["image_sizes"] = concatenated_batch["image_sizes"]
 
         prompt_input_ids = concatenated_batch["prompt_input_ids"]
         prompt_attention_mask = concatenated_batch["prompt_attention_mask"]

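With these pieces in place, a LLaVA-Next checkpoint can be trained with DPOTrainer like the other supported vision models. Below is a minimal end-to-end sketch using the tiny test checkpoint referenced above; the synthetic two-example dataset and its conversational prompt/chosen/rejected-plus-images schema are assumptions modeled on TRL's vision DPO tests, so check the documentation for the schema your TRL version expects.

import numpy as np
from datasets import Dataset, Image as ImageFeature, Sequence
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor
from trl import DPOConfig, DPOTrainer

model_id = "trl-internal-testing/tiny-LlavaNextForConditionalGeneration"
model = AutoModelForVision2Seq.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

def random_image():
    # Synthetic 64x64 RGB image so the example stays self-contained.
    return Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8))

# Two-example preference dataset in a conversational layout (assumed schema).
dataset = Dataset.from_dict(
    {
        "images": [[random_image()], [random_image()]],
        "prompt": [
            [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe the image."}]}],
            [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What do you see?"}]}],
        ],
        "chosen": [
            [{"role": "assistant", "content": [{"type": "text", "text": "Random colored noise."}]}],
            [{"role": "assistant", "content": [{"type": "text", "text": "Static-like pixels."}]}],
        ],
        "rejected": [
            [{"role": "assistant", "content": [{"type": "text", "text": "A photo of a cat."}]}],
            [{"role": "assistant", "content": [{"type": "text", "text": "A city skyline."}]}],
        ],
    }
)
dataset = dataset.cast_column("images", Sequence(ImageFeature()))

training_args = DPOConfig(output_dir="llava-next-dpo", per_device_train_batch_size=2, max_steps=2, report_to="none")
trainer = DPOTrainer(model=model, args=training_args, train_dataset=dataset, processing_class=processor)
trainer.train()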