diff --git a/docs/source/dpo_trainer.mdx b/docs/source/dpo_trainer.mdx index a7c8fd8c36..feb2e2adf2 100644 --- a/docs/source/dpo_trainer.mdx +++ b/docs/source/dpo_trainer.mdx @@ -70,8 +70,33 @@ dpo_dataset_dict = { where the `prompt` contains the context inputs, `chosen` contains the corresponding chosen responses and `rejected` contains the corresponding negative (rejected) responses. As can be seen a prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays. +`DPOTrainer` can be used to fine-tune visual language models (VLMs). In this case, the dataset must also contain the key `images`, and the trainer's `tokenizer` is the VLM's `processor`. For example, for Idefics2, the processor expects the dataset to have the following format: + +Note: Currently, VLM support is exclusive to Idefics2 and does not extend to other VLMs. + +```py +dpo_dataset_dict = { + 'images': [ + [Image.open('beach.jpg')], + [Image.open('street.jpg')], + ], + 'prompt': [ + 'The image shows', + ' The image depicts', + ], + 'chosen': [ + 'a sunny beach with palm trees.', + 'a busy street with several cars and buildings.', + ], + 'rejected': [ + 'a snowy mountain with skiers.', + 'a calm countryside with green fields.', + ], +} +``` + ## Expected model format -The DPO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function. +The DPO trainer expects a model of `AutoModelForCausalLM` or `AutoModelForVision2Seq`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function. ## Using the `DPOTrainer` @@ -86,7 +111,7 @@ dpo_trainer = DPOTrainer( model_ref, args=training_args, train_dataset=train_dataset, - tokenizer=tokenizer, + tokenizer=tokenizer, # for visual language models, use tokenizer=processor instead ) ``` After this one can then call: diff --git a/examples/scripts/dpo_visual.py b/examples/scripts/dpo_visual.py new file mode 100644 index 0000000000..b602c8f009 --- /dev/null +++ b/examples/scripts/dpo_visual.py @@ -0,0 +1,177 @@ +# flake8: noqa +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +accelerate launch examples/scripts/vdpo.py \ + --dataset_name HuggingFaceH4/rlaif-v_formatted \ + --model_name_or_path HuggingFaceM4/idefics2-8b \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 16 \ + --dataset_num_proc 32 \ + --output_dir dpo_idefics_rlaif-v \ + --bf16 \ + --torch_dtype bfloat16 \ + --use_peft \ + --lora_target_modules=all-linear +""" + +import logging +import os +from contextlib import nullcontext + +TRL_USE_RICH = os.environ.get("TRL_USE_RICH", False) + +from trl.commands.cli_utils import DPOScriptArguments, init_zero_verbose, TrlParser +from accelerate import PartialState + +if TRL_USE_RICH: + init_zero_verbose() + FORMAT = "%(message)s" + + from rich.console import Console + from rich.logging import RichHandler + +import torch +from datasets import load_dataset +from transformers import AutoModelForVision2Seq, AutoProcessor + +from trl import ( + DPOConfig, + DPOTrainer, + ModelConfig, + RichProgressCallback, + get_kbit_device_map, + get_peft_config, + get_quantization_config, +) + + +if TRL_USE_RICH: + logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.INFO) + + +if __name__ == "__main__": + parser = TrlParser((DPOScriptArguments, DPOConfig, ModelConfig)) + args, training_args, model_config = parser.parse_args_and_config() + + # Force use our print callback + if TRL_USE_RICH: + training_args.disable_tqdm = True + console = Console() + + ################ + # Model & Tokenizer + ################ + torch_dtype = ( + model_config.torch_dtype + if model_config.torch_dtype in ["auto", None] + else getattr(torch, model_config.torch_dtype) + ) + quantization_config = get_quantization_config(model_config) + + model_kwargs = dict( + revision=model_config.model_revision, + trust_remote_code=model_config.trust_remote_code, + attn_implementation=model_config.attn_implementation, + torch_dtype=torch_dtype, + use_cache=False if training_args.gradient_checkpointing else True, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + ) + model = AutoModelForVision2Seq.from_pretrained(model_config.model_name_or_path, **model_kwargs) + peft_config = get_peft_config(model_config) + if peft_config is None: + model_ref = AutoModelForVision2Seq.from_pretrained(model_config.model_name_or_path, **model_kwargs) + else: + model_ref = None + processor = AutoProcessor.from_pretrained(model_config.model_name_or_path, do_image_splitting=False) + tokenizer = processor.tokenizer + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + if args.ignore_bias_buffers: + # torch distributed hack + model._ddp_params_and_buffers_to_ignore = [ + name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool + ] + + ################ + # Optional rich context managers + ############### + init_context = nullcontext() if not TRL_USE_RICH else console.status("[bold green]Initializing the DPOTrainer...") + save_context = ( + nullcontext() + if not TRL_USE_RICH + else console.status(f"[bold green]Training completed! Saving the model to {training_args.output_dir}") + ) + + ################ + # Dataset + ################ + ds = load_dataset(args.dataset_name) + if args.sanity_check: + for key in ds: + ds[key] = ds[key].select(range(50)) + + def process(row): + # The prompt can be either a string or a list. In some datasets, the prompt is just a common string + # for both rejected and chosen (already included in chosen and rejected) and is not meant to be used + # separately. In other datasets, the prompt is intended to be used as a prefix for rejected and chosen, + # and in such cases, it is properly formatted as a list with keys "role" and "content". + # Example 1: + # row = {"prompt": "What does detox mean?", + # "chosen": [{"content": "What does detox mean?", "role": "user"}, {"content": "It means to get rid of the toxins.", "role": "assistant"}], + # "rejected": [{"content": "What does detox mean?", "role": "assistant"}, {"content": "I don't know.", "role": "user"}]} + # Example 2: + # row = {"prompt": [{"content": "What does detox mean?", "role": "user"}], + # "chosen": [{"content": "It means to get rid of the toxins.", "role": "assistant"}], + # "rejected": [{"content": "I don't know.", "role": "user"}]} + if "prompt" in row and isinstance(row["prompt"], list): + row["prompt"] = processor.apply_chat_template(row["prompt"], tokenize=False) + + row["chosen"] = processor.apply_chat_template(row["chosen"], tokenize=False) + row["rejected"] = processor.apply_chat_template(row["rejected"], tokenize=False) + + if "images" in row: + for idx, img in enumerate(row["images"]): # Resize each image so the largest side is 640 pixels + ratio = min(1.0, 640 / max(img.size)) + new_size = (int(img.size[0] * ratio), int(img.size[1] * ratio)) + row["images"][idx] = img.resize(new_size) + row["images"] = row["images"] + + return row + + with PartialState().local_main_process_first(): + ds = ds.map(process, num_proc=training_args.dataset_num_proc) + train_dataset = ds[args.dataset_train_split] + eval_dataset = ds[args.dataset_test_split] + + ################ + # Training + ################ + with init_context: + trainer = DPOTrainer( + model, + model_ref, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=processor, + peft_config=get_peft_config(model_config), + callbacks=[RichProgressCallback] if TRL_USE_RICH else None, + ) + + trainer.train() + trainer.push_to_hub + with save_context: + trainer.save_model(training_args.output_dir) diff --git a/setup.py b/setup.py index 65ae9293d5..7180babb93 100644 --- a/setup.py +++ b/setup.py @@ -69,7 +69,16 @@ "tyro>=0.5.11", ] EXTRAS = { - "test": ["parameterized", "pytest", "pytest-xdist", "accelerate", "pytest-cov", "pytest-xdist", "scikit-learn"], + "test": [ + "parameterized", + "pytest", + "pytest-xdist", + "accelerate", + "pytest-cov", + "pytest-xdist", + "scikit-learn", + "Pillow", + ], "peft": ["peft>=0.4.0"], "diffusers": ["diffusers>=0.18.0"], "deepspeed": ["deepspeed>=0.9.5"], diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index 703d6af2e8..ae049b5491 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -15,10 +15,17 @@ import unittest import torch -from datasets import Dataset +from datasets import Dataset, features from parameterized import parameterized +from PIL import Image from pytest import mark -from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer +from transformers import ( + AutoModelForCausalLM, + AutoModelForSeq2SeqLM, + AutoModelForVision2Seq, + AutoProcessor, + AutoTokenizer, +) from trl import DPOConfig, DPOTrainer, FDivergenceType @@ -40,6 +47,12 @@ def setUpClass(cls): cls.t5_ref_model = AutoModelForSeq2SeqLM.from_pretrained(model_id) cls.t5_tokenizer = AutoTokenizer.from_pretrained(model_id) + # get idefics2 model + model_id = "trl-internal-testing/tiny-random-idefics2" + cls.idefics2_model = AutoModelForVision2Seq.from_pretrained(model_id) + cls.idefics2_ref_model = AutoModelForVision2Seq.from_pretrained(model_id) + cls.idefics2_processor = AutoProcessor.from_pretrained(model_id) + def _init_dummy_dataset(self): # fmt: off dummy_dataset_dict = { @@ -80,6 +93,57 @@ def _init_dummy_dataset(self): # fmt: on return Dataset.from_dict(dummy_dataset_dict) + def _init_dummy_image_dataset(self): + # fmt: off + dummy_dataset_dict = { + "images": [ + [Image.new("RGB", (100, 50), color="black")], + # None, + # [Image.new("RGB", (100, 100), color="blue"), Image.new("RGB", (150, 50), color="red")], + [Image.new("RGB", (200, 100), color="green")], + # [Image.new("RGB", (150, 150), color="yellow"), Image.new("RGB", (50, 150), color="purple")], + [Image.new("RGB", (80, 120), color="gray")], + [Image.new("RGB", (120, 80), color="pink")], + ], + "prompt": [ + " Hello", + # "How are you?", + # " Let's chat", + " Good morning", + # " What's up?", + "Can you see this? ", + "Here is something interesting: ", + ], + "chosen": [ + "Hi nice to meet you!", + # "I'm doing well, thank you!", + # "Sure, let's talk!", + "Good morning to you too!", + # "Not much, just working.", + "Yes, I can see it clearly.", + "That's quite interesting indeed.", + ], + "rejected": [ + "Leave me alone!", + # "I'm not interested.", + # "I don't want to chat.", + "I'm still sleepy.", + # "Busy right now, talk later.", + "No, I can't see it.", + "I'm not sure what that is.", + ], + } + # fmt: on + f = features.Features( + { + "images": features.Sequence(features.Image(decode=True)), # datasets handles badly sequence of images + "prompt": features.Value("string"), + "chosen": features.Value("string"), + "rejected": features.Value("string"), + } + ) + return Dataset.from_dict(dummy_dataset_dict, features=f) + @parameterized.expand( [ ["gpt2", "sigmoid", True], @@ -152,6 +216,54 @@ def test_dpo_trainer(self, name, loss_type, pre_compute): if param.sum() != 0: assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12) + @parameterized.expand( + [ + ["sigmoid", True], + ] + ) + def test_vdpo_trainer(self, loss_type, pre_compute): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + beta=0.1, + loss_type=loss_type, + precompute_ref_log_probs=pre_compute, + ) + + dummy_dataset = self._init_dummy_image_dataset() + + model = self.idefics2_model + ref_model = self.idefics2_ref_model + processor = self.idefics2_processor + + trainer = DPOTrainer( + model=model, + ref_model=ref_model, + args=training_args, + tokenizer=processor, + train_dataset=dummy_dataset, + eval_dataset=dummy_dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + # check the params have changed - ignore 0 biases + if param.sum() != 0: + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12) + def test_dpo_trainer_without_providing_ref_model(self): with tempfile.TemporaryDirectory() as tmp_dir: training_args = DPOConfig( @@ -811,3 +923,7 @@ def test_dpo_loss_js_div_f(self): policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps ) assert torch.isfinite(losses).cpu().numpy().all() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000000..5e5c3ec9c9 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,57 @@ +import unittest + +import torch + +from trl.trainer.utils import pad + + +class TestPad(unittest.TestCase): + def test_pad_1_dim_left(self): + x = torch.tensor([1, 2, 3]) + y = torch.tensor([4, 5]) + output = pad((x, y), padding_value=0, padding_side="left") + expected = torch.tensor([[1, 2, 3], [0, 4, 5]]) + self.assertTrue(torch.equal(output, expected)) + + def test_pad_1_dim_right(self): + x = torch.tensor([1, 2, 3]) + y = torch.tensor([4, 5]) + output = pad((x, y), padding_value=0, padding_side="right") + expected = torch.tensor([[1, 2, 3], [4, 5, 0]]) + self.assertTrue(torch.equal(output, expected)) + + def test_pad_2_dim_left(self): + x = torch.tensor([[1, 2], [3, 4]]) + y = torch.tensor([[5, 6]]) + output = pad((x, y), padding_value=0, padding_side="left") + expected = torch.tensor( + [ + [[1, 2], [3, 4]], + [[0, 0], [5, 6]], + ] + ) + self.assertTrue(torch.equal(output, expected)) + + def test_pad_2_dim_right(self): + x = torch.tensor([[1, 2], [3, 4]]) + y = torch.tensor([[5, 6]]) + output = pad((x, y), padding_value=0, padding_side="right") + expected = torch.tensor( + [ + [[1, 2], [3, 4]], + [[5, 6], [0, 0]], + ] + ) + self.assertTrue(torch.equal(output, expected)) + + def test_pad_2_dim_right_multidim(self): + x = torch.tensor([[1, 2], [3, 4]]) + y = torch.tensor([[5]]) + output = pad((x, y), padding_value=0, padding_side="right") + expected = torch.tensor( + [ + [[1, 2], [3, 4]], + [[5, 0], [0, 0]], + ] + ) + self.assertTrue(torch.equal(output, expected)) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 0fee33f22a..bf30d1fd66 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -37,6 +37,7 @@ PreTrainedTokenizerBase, Trainer, ) +from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import EvalLoopOutput @@ -316,6 +317,20 @@ def make_inputs_require_grad(module, input, output): else: self.is_encoder_decoder = args.is_encoder_decoder + if model is not None: + self.is_vision_model = model.config.model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES.keys() + else: + warnings.warn( + "No model provided, cannot determine if it is a vision model. Setting is_vision_model to False." + ) + self.is_vision_model = False + + if self.is_vision_model: + self.processor = tokenizer + self.tokenizer = tokenizer.tokenizer # tokenizer is actually a processor at this point + else: + self.tokenizer = tokenizer + self.is_peft_model = is_peft_available() and isinstance(model, PeftModel) if model_adapter_name is not None: warnings.warn( @@ -401,7 +416,7 @@ def make_inputs_require_grad(module, input, output): args.label_pad_token_id = label_pad_token_id if data_collator is None: data_collator = DPODataCollatorWithPadding( - pad_token_id=tokenizer.pad_token_id, + pad_token_id=self.tokenizer.pad_token_id, label_pad_token_id=args.label_pad_token_id, is_encoder_decoder=self.is_encoder_decoder, ) @@ -437,7 +452,7 @@ def make_inputs_require_grad(module, input, output): "You passed `padding_value` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`." ) args.padding_value = padding_value - self.padding_value = args.padding_value if padding_value is not None else tokenizer.pad_token_id + self.padding_value = args.padding_value if padding_value is not None else self.tokenizer.pad_token_id self.max_prompt_length = args.max_prompt_length if truncation_mode != "keep_end": warnings.warn( @@ -446,7 +461,6 @@ def make_inputs_require_grad(module, input, output): args.truncation_mode = truncation_mode self.truncation_mode = args.truncation_mode self.max_target_length = args.max_target_length - self.tokenizer = tokenizer self.precompute_ref_log_probs = args.precompute_ref_log_probs # Since ref_logs are precomputed on the first call to get_train/eval_dataloader @@ -496,10 +510,12 @@ def make_inputs_require_grad(module, input, output): # Compute that only on the main process for faster data processing. # see: https://github.com/huggingface/trl/pull/1255 with PartialState().local_main_process_first(): - # tokenize the dataset - train_dataset = train_dataset.map(self.tokenize_row, num_proc=self.dataset_num_proc) + # tokenize the dataset, lower writer batch size to avoid OOM (frequent in vision models) + train_dataset = train_dataset.map(self.tokenize_row, num_proc=self.dataset_num_proc, writer_batch_size=10) if eval_dataset is not None: - eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=self.dataset_num_proc) + eval_dataset = eval_dataset.map( + self.tokenize_row, num_proc=self.dataset_num_proc, writer_batch_size=10 + ) super().__init__( model=model, @@ -682,16 +698,22 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa return super().get_eval_dataloader(eval_dataset=eval_dataset) - def build_tokenized_answer(self, prompt, answer): + def build_tokenized_answer(self, prompt, answer, images=None): """ Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`. It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`. Reference: https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 """ - - full_tokenized = self.tokenizer(prompt + answer, add_special_tokens=False) - prompt_input_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"] + if self.is_vision_model: + if answer.count("") > 0: + raise NotImplementedError("Answer contains token, which is not supported yet.") + full_tokenized = self.processor(prompt + answer, images=images, add_special_tokens=False) + full_tokenized = {k: v[0] for k, v in full_tokenized.items()} # Unbatch, not done when using idefics + prompt_input_ids = self.processor(prompt, images=images, add_special_tokens=False)["input_ids"][0] + else: + full_tokenized = self.tokenizer(prompt + answer, add_special_tokens=False) + prompt_input_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"] answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :] answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :] @@ -725,12 +747,22 @@ def build_tokenized_answer(self, prompt, answer): answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:] answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:] - return dict( - prompt_input_ids=prompt_input_ids, - prompt_attention_mask=prompt_attention_mask, - input_ids=answer_input_ids, - attention_mask=answer_attention_mask, - ) + if "pixel_values" in full_tokenized: + return dict( + prompt_input_ids=prompt_input_ids, + prompt_attention_mask=prompt_attention_mask, + prompt_pixel_values=full_tokenized["pixel_values"], + prompt_pixel_attention_mask=full_tokenized["pixel_attention_mask"], + input_ids=answer_input_ids, + attention_mask=answer_attention_mask, + ) + else: + return dict( + prompt_input_ids=prompt_input_ids, + prompt_attention_mask=prompt_attention_mask, + input_ids=answer_input_ids, + attention_mask=answer_attention_mask, + ) def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> Dict: """Tokenize a single row from a DPO specific dataset. @@ -747,6 +779,7 @@ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module prompt = feature["prompt"] chosen = feature["chosen"] rejected = feature["rejected"] + images = feature.get("images") if not self.is_encoder_decoder: # Check issues below for more details @@ -756,16 +789,22 @@ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module if not isinstance(prompt, str): raise ValueError(f"prompt should be an str but got {type(prompt)}") - prompt_tokens = self.tokenizer(prompt, add_special_tokens=False) + if self.is_vision_model: + prompt_tokens = self.processor(prompt, images=images, add_special_tokens=False) + prompt_tokens = {k: v[0] for k, v in prompt_tokens.items()} # Unbatch, not done when using idefics + else: + prompt_tokens = self.tokenizer(prompt, add_special_tokens=False) + prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()} if not isinstance(chosen, str): raise ValueError(f"chosen should be an str but got {type(chosen)}") - chosen_tokens = self.build_tokenized_answer(prompt, chosen) + + chosen_tokens = self.build_tokenized_answer(prompt, chosen, images) if not isinstance(rejected, str): raise ValueError(f"rejected should be an str but got {type(rejected)}") - rejected_tokens = self.build_tokenized_answer(prompt, rejected) + rejected_tokens = self.build_tokenized_answer(prompt, rejected, images) # Last prompt token might get merged by tokenizer and # it should not be included for generation if that happens @@ -925,6 +964,7 @@ def compute_reference_log_probs(self, padded_batch: Dict) -> Dict: def concatenated_inputs( batch: Dict[str, Union[List, torch.LongTensor]], is_encoder_decoder: bool = False, + is_vision_model: bool = False, label_pad_token_id: int = -100, padding_value: int = 0, device: Optional[torch.device] = None, @@ -981,6 +1021,11 @@ def concatenated_inputs( batch["prompt_attention_mask"].repeat(2, 1).to(device=device) ) + if is_vision_model: + concatenated_batch["pixel_values"] = batch["prompt_pixel_values"].repeat(2, 1, 1, 1, 1).to(device=device) + concatenated_batch["pixel_attention_mask"] = ( + batch["prompt_pixel_attention_mask"].repeat(2, 1, 1, 1).to(device=device) + ) return concatenated_batch def dpo_loss( @@ -1187,20 +1232,23 @@ def concatenated_forward( concatenated_batch = self.concatenated_inputs( batch, is_encoder_decoder=self.is_encoder_decoder, + is_vision_model=self.is_vision_model, label_pad_token_id=self.label_pad_token_id, padding_value=self.padding_value, device=self.accelerator.device, ) len_chosen = batch["chosen_labels"].shape[0] - model_kwargs = ( - { - "labels": concatenated_batch["concatenated_labels"], - "decoder_input_ids": concatenated_batch.pop("concatenated_decoder_input_ids", None), - } - if self.is_encoder_decoder - else {} - ) + model_kwargs = {} + + if self.is_encoder_decoder: + model_kwargs["labels"] = concatenated_batch["concatenated_labels"] + model_kwargs["decoder_input_ids"] = concatenated_batch.pop("concatenated_decoder_input_ids", None) + + if self.is_vision_model: + model_kwargs["pixel_values"] = concatenated_batch["pixel_values"] + model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"] + if self.aux_loss_enabled: model_kwargs["output_router_logits"] = True diff --git a/trl/trainer/model_config.py b/trl/trainer/model_config.py index c30fa4ae49..b16a07421d 100644 --- a/trl/trainer/model_config.py +++ b/trl/trainer/model_config.py @@ -86,5 +86,5 @@ def __post_init__(self): if self.load_in_8bit and self.load_in_4bit: raise ValueError("You can't use 8 bit and 4 bit precision at the same time") - if self.lora_target_modules == ["all-linear"]: - self.lora_target_modules = "all-linear" + if isinstance(self.lora_target_modules, list) and len(self.lora_target_modules) == 1: + self.lora_target_modules = self.lora_target_modules[0] diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 19531eaf54..aa44c21ce2 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -336,6 +336,55 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: return batch +def pad(tensors: List[torch.Tensor], padding_value: int = 0, padding_side: str = "right") -> torch.Tensor: + """ + Pads a list of tensors to the same shape along the first dimension. + + Args: + tensors (`List[torch.Tensor]`): + List of input tensors to pad. + padding_value (`int`): + Value to use for padding. Default is 0. + padding_side (`str`): + Side on which to add padding. Must be 'left' or 'right'. Default is 'right'. + + Returns: + `torch.Tensor`: + A single tensor containing the padded tensors. + + Examples: + >>> import torch + >>> pad([torch.tensor([1, 2, 3]), torch.tensor([4, 5])]) + tensor([[1, 2, 3], + [4, 5, 0]]) + >>> pad([torch.tensor([[1, 2], [3, 4]]), torch.tensor([[5, 6]])]) + tensor([[[1, 2], + [3, 4]], + + [[5, 6], + [0, 0]]]) + """ + # Determine the maximum shape for each dimension + output_shape = np.max([t.shape for t in tensors], 0).tolist() + + # Create an output tensor filled with the padding value + output = torch.full((len(tensors), *output_shape), padding_value, dtype=tensors[0].dtype, device=tensors[0].device) + + for i, t in enumerate(tensors): + # Determine the slice for the sequence dimension + if padding_side == "left": + seq_slice = slice(output_shape[0] - t.shape[0], output_shape[0]) + elif padding_side == "right": + seq_slice = slice(0, t.shape[0]) + else: + raise ValueError("padding_side must be 'left' or 'right'") + + slices = (seq_slice,) + tuple(slice(0, s) for s in t.shape[1:]) + output[i][slices] = t + + return output + + @dataclass class DPODataCollatorWithPadding: r""" @@ -357,7 +406,7 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: # first, pad everything to the same length padded_batch = {} for k in features[0].keys(): - if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"): + if k.endswith(("_input_ids", "_attention_mask", "_labels", "_pixel_values")): if self.is_encoder_decoder: to_pad = [torch.LongTensor(ex[k]) for ex in features] @@ -377,11 +426,7 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: raise ValueError(f"Unexpected key in batch '{k}'") padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value) else: - # adapted from https://stackoverflow.com/questions/73256206 - if "prompt" in k: - to_pad = [torch.LongTensor(ex[k][::-1]) for ex in features] - else: - to_pad = [torch.LongTensor(ex[k]) for ex in features] + # Set padding value based on the key if k.endswith("_input_ids"): if self.pad_token_id is None: raise ValueError( @@ -394,13 +439,26 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: padding_value = self.label_pad_token_id elif k.endswith("_attention_mask"): padding_value = 0 + elif k.endswith("_pixel_values"): + padding_value = 0 # TODO: check if this is correct else: raise ValueError(f"Unexpected key in batch '{k}'") - padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value) - # for the prompt, flip back so padding is on left side - if "prompt" in k: - padded_batch[k] = padded_batch[k].flip(dims=[1]) + # Set padding side based on the key + if k in ["prompt_input_ids", "prompt_attention_mask"]: + padding_side = "left" + else: + padding_side = "right" + + # Set the dtype + if k.endswith("_pixel_values"): + dtype = torch.float32 # will be downcasted if necessary by the Trainer + else: + dtype = torch.int64 + + # Convert to tensor and pad + to_pad = [torch.tensor(ex[k], dtype=dtype) for ex in features] + padded_batch[k] = pad(to_pad, padding_value=padding_value, padding_side=padding_side) elif k.endswith("_logps"): # the cached reference model logprobs padded_batch[k] = torch.tensor([ex[k] for ex in features])