Skip to content

Commit

Permalink
feat: add tests for pretokenized dataset packing
Browse files Browse the repository at this point in the history
Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
  • Loading branch information
kmehant committed Nov 5, 2024
1 parent 1ae3f93 commit b13e324
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 5 deletions.
99 changes: 96 additions & 3 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ def formatting_prompts_func(example):
return text


def formatting_func_for_pretokenized(example):
    """Return the token ids of an already-tokenized dataset example.

    Used as ``formatting_func`` for ``ConstantLengthDataset`` when the
    dataset rows carry an ``input_ids`` column instead of raw text.
    """
    token_ids = example["input_ids"]
    return token_ids


def formatting_prompts_func_batched(example):
output_text = []
for i, question in enumerate(example["question"]):
Expand Down Expand Up @@ -93,6 +97,17 @@ def setUp(self):
],
}
)
self.dummy_tokenized_dataset = Dataset.from_dict(
{
"input_ids": [
self.tokenizer.encode(
"TRL is a library to post-train LLMs and diffusion models with methods such as Supervised Fine-tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO)."
)
]
* 10
}
)

self.conversational_lm_dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling")
self.standard_prompt_completion_dataset = load_dataset(
"trl-internal-testing/zen", "standard_prompt_completion"
Expand All @@ -105,15 +120,23 @@ def setUp(self):
[
{
"role": "user",
"content": [{"type": "text", "text": "What is in this image?"}, {"type": "image"}],
"content": [
{"type": "text", "text": "What is in this image?"},
{"type": "image"},
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "It is random noise."}],
},
{
"role": "user",
"content": [{"type": "text", "text": "Oh ye, you are right, what is 1+1"}],
"content": [
{
"type": "text",
"text": "Oh ye, you are right, what is 1+1",
}
],
},
{
"role": "assistant",
Expand All @@ -123,7 +146,10 @@ def setUp(self):
[
{
"role": "user",
"content": [{"type": "text", "text": "What is in this image?"}, {"type": "image"}],
"content": [
{"type": "text", "text": "What is in this image?"},
{"type": "image"},
],
},
{
"role": "assistant",
Expand Down Expand Up @@ -158,6 +184,45 @@ def setUp(self):
num_of_sequences=16,
)

self.train_dataset_from_pretokenized = ConstantLengthDataset(
self.tokenizer,
self.dummy_tokenized_dataset,
dataset_text_field=None,
seq_length=16,
num_of_sequences=16,
formatting_func=formatting_func_for_pretokenized,
)

self.eval_dataset_from_pretokenized = ConstantLengthDataset(
self.tokenizer,
self.dummy_tokenized_dataset,
dataset_text_field=None,
seq_length=16,
num_of_sequences=16,
formatting_func=formatting_func_for_pretokenized,
)

def test_constant_length_dataset_with_pretokenized_data(self):
    """Pack an already-tokenized dataset and verify each packed example."""
    packed = ConstantLengthDataset(
        self.tokenizer,
        self.dummy_tokenized_dataset,
        dataset_text_field=None,
        formatting_func=formatting_func_for_pretokenized,
    )

    # The packed dataset reports one entry per source row and is non-empty.
    assert len(packed) == len(self.dummy_tokenized_dataset)
    assert len(packed) > 0

    for sample in packed:
        # Every packed sample carries both model inputs and labels ...
        assert "input_ids" in sample
        assert "labels" in sample

        # ... each exactly `seq_length` tokens long.
        assert len(sample["input_ids"]) == packed.seq_length
        assert len(sample["labels"]) == packed.seq_length

        # Decoding recovers content from the original source sentence.
        round_trip = self.tokenizer.decode(sample["input_ids"])
        assert ("TRL" in round_trip) and ("(DPO)" in round_trip)

def test_constant_length_dataset(self):
formatted_dataset = ConstantLengthDataset(
self.tokenizer,
Expand Down Expand Up @@ -236,6 +301,34 @@ def test_sft_trainer(self):

self.assertIn("model.safetensors", os.listdir(tmp_dir + "/checkpoint-2"))

def test_sft_trainer_with_pretokenzied_data_packing(self):
    """Run SFTTrainer end-to-end on the pre-packed ConstantLengthDataset.

    NOTE(review): the "pretokenzied" typo in the method name is kept so
    existing test selection (e.g. ``-k`` filters, CI configs) keeps working.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        training_args = SFTConfig(
            output_dir=tmp_dir,
            dataloader_drop_last=True,
            eval_strategy="steps",
            max_steps=4,
            eval_steps=2,
            save_steps=2,
            per_device_train_batch_size=2,
            packing=True,
            report_to="none",
        )

        trainer = SFTTrainer(
            model=self.model_id,
            args=training_args,
            train_dataset=self.train_dataset_from_pretokenized,
            eval_dataset=self.eval_dataset_from_pretokenized,
        )

        trainer.train()

        # Both a training loss and an evaluation loss must have been
        # logged (eval_steps=2 guarantees at least one eval entry).
        # Use unittest-style assertions for consistency with the other
        # tests in this class (cf. self.assertIn in test_sft_trainer),
        # and drop the superfluous parentheses around the -1 index.
        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
        self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])

        # save_steps=2 must have produced an intermediate checkpoint.
        self.assertIn("model.safetensors", os.listdir(tmp_dir + "/checkpoint-2"))

def test_sft_trainer_uncorrect_data(self):
with tempfile.TemporaryDirectory() as tmp_dir:
# Should work as SFTTrainer natively supports conversational lm dataset
Expand Down
4 changes: 2 additions & 2 deletions trl/trainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,8 +622,8 @@ def __init__(
else: # neither is provided
raise ValueError("Either `dataset_text_field` or `formatting_func` should be provided.")

if self.formatting_func is not None:
if self.formatting_func.__code__.co_argcount > 1:
if formatting_func is not None:
if formatting_func.__code__.co_argcount > 1:
warnings.warn(
"The passed formatting_func has more than one argument. Usually that function should have a single argument `example`"
" which corresponds to the dictionary returned by each element of the dataset. Make sure you know what you are doing."
Expand Down

0 comments on commit b13e324

Please sign in to comment.