diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index e615caf0431..02ef4d1b039 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -39,6 +39,10 @@ def formatting_prompts_func(example):
     return text
 
 
+def formatting_func_for_pretokenized(example):
+    return example["input_ids"]
+
+
 def formatting_prompts_func_batched(example):
     output_text = []
     for i, question in enumerate(example["question"]):
@@ -93,6 +97,17 @@ def setUp(self):
                 ],
             }
         )
+        self.dummy_tokenized_dataset = Dataset.from_dict(
+            {
+                "input_ids": [
+                    self.tokenizer.encode(
+                        "TRL is a library to post-train LLMs and diffusion models with methods such as Supervised Fine-tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO)."
+                    )
+                ]
+                * 10
+            }
+        )
+
         self.conversational_lm_dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling")
         self.standard_prompt_completion_dataset = load_dataset(
             "trl-internal-testing/zen", "standard_prompt_completion"
@@ -105,7 +120,10 @@ def setUp(self):
                     [
                         {
                             "role": "user",
-                            "content": [{"type": "text", "text": "What is in this image?"}, {"type": "image"}],
+                            "content": [
+                                {"type": "text", "text": "What is in this image?"},
+                                {"type": "image"},
+                            ],
                         },
                         {
                             "role": "assistant",
@@ -113,7 +131,12 @@ def setUp(self):
                         },
                         {
                             "role": "user",
-                            "content": [{"type": "text", "text": "Oh ye, you are right, what is 1+1"}],
+                            "content": [
+                                {
+                                    "type": "text",
+                                    "text": "Oh ye, you are right, what is 1+1",
+                                }
+                            ],
                         },
                         {
                             "role": "assistant",
@@ -123,7 +146,10 @@ def setUp(self):
                     [
                         {
                             "role": "user",
-                            "content": [{"type": "text", "text": "What is in this image?"}, {"type": "image"}],
+                            "content": [
+                                {"type": "text", "text": "What is in this image?"},
+                                {"type": "image"},
+                            ],
                         },
                         {
                             "role": "assistant",
@@ -158,6 +184,45 @@ def setUp(self):
             num_of_sequences=16,
         )
 
+        self.train_dataset_from_pretokenized = ConstantLengthDataset(
+            self.tokenizer,
+            self.dummy_tokenized_dataset,
+            dataset_text_field=None,
+            seq_length=16,
+            num_of_sequences=16,
+            formatting_func=formatting_func_for_pretokenized,
+        )
+
+        self.eval_dataset_from_pretokenized = ConstantLengthDataset(
+            self.tokenizer,
+            self.dummy_tokenized_dataset,
+            dataset_text_field=None,
+            seq_length=16,
+            num_of_sequences=16,
+            formatting_func=formatting_func_for_pretokenized,
+        )
+
+    def test_constant_length_dataset_with_pretokenized_data(self):
+        constant_len_dataset = ConstantLengthDataset(
+            self.tokenizer,
+            self.dummy_tokenized_dataset,
+            dataset_text_field=None,
+            formatting_func=formatting_func_for_pretokenized,
+        )
+
+        assert len(constant_len_dataset) == len(self.dummy_tokenized_dataset)
+        assert len(constant_len_dataset) > 0
+
+        for example in constant_len_dataset:
+            assert "input_ids" in example
+            assert "labels" in example
+
+            assert len(example["input_ids"]) == constant_len_dataset.seq_length
+            assert len(example["labels"]) == constant_len_dataset.seq_length
+
+            decoded_text = self.tokenizer.decode(example["input_ids"])
+            assert ("TRL" in decoded_text) and ("(DPO)" in decoded_text)
+
     def test_constant_length_dataset(self):
         formatted_dataset = ConstantLengthDataset(
             self.tokenizer,
@@ -236,6 +301,34 @@ def test_sft_trainer(self):
 
             self.assertIn("model.safetensors", os.listdir(tmp_dir + "/checkpoint-2"))
 
+    def test_sft_trainer_with_pretokenized_data_packing(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = SFTConfig(
+                output_dir=tmp_dir,
+                dataloader_drop_last=True,
+                eval_strategy="steps",
+                max_steps=4,
+                eval_steps=2,
+                save_steps=2,
+                per_device_train_batch_size=2,
+                packing=True,
+                report_to="none",
+            )
+
+            trainer = SFTTrainer(
+                model=self.model_id,
+                args=training_args,
+                train_dataset=self.train_dataset_from_pretokenized,
+                eval_dataset=self.eval_dataset_from_pretokenized,
+            )
+
+            trainer.train()
+
+            assert trainer.state.log_history[(-1)]["train_loss"] is not None
+            assert trainer.state.log_history[0]["eval_loss"] is not None
+
+            assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
+
     def test_sft_trainer_uncorrect_data(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             # Shoud work as SFTTrainer natively supports conversational lm dataset
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index c9ea619d63f..f2854a1cec0 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -622,8 +622,8 @@ def __init__(
         else:  # neither is provided
             raise ValueError("Either `dataset_text_field` or `formatting_func` should be provided.")
 
-        if self.formatting_func is not None:
-            if self.formatting_func.__code__.co_argcount > 1:
+        if formatting_func is not None:
+            if formatting_func.__code__.co_argcount > 1:
                 warnings.warn(
                     "The passed formatting_func has more than one argument. Usually that function should have a single argument `example`"
                     " which corresponds to the dictionary returned by each element of the dataset. Make sure you know what you are doing."