feat: add tests for pretokenized dataset packing
Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
kmehant committed Oct 8, 2024
1 parent 4564ef8 commit 0d0f34c
Showing 1 changed file with 89 additions and 3 deletions.
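For context, here is a minimal sketch (not part of the commit) of the behavior the new tests cover: with dataset_text_field=None, ConstantLengthDataset packs examples that already contain token ids instead of tokenizing a text column. The import path, the "gpt2" tokenizer, and the sample sentence are illustrative assumptions, not taken from the diff.

from datasets import Dataset
from transformers import AutoTokenizer
from trl.trainer import ConstantLengthDataset

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer

# Ten copies of one pretokenized sentence; note there is no text column.
pretokenized = Dataset.from_dict(
    {"input_ids": [tokenizer.encode("TRL packs pretokenized data.")] * 10}
)

packed = ConstantLengthDataset(
    tokenizer,
    pretokenized,
    dataset_text_field=None,  # signals that inputs are already token ids
    seq_length=16,
    num_of_sequences=16,
)

for example in packed:
    # Each yielded example is a fixed-length block of token ids plus labels.
    assert len(example["input_ids"]) == 16
    break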
92 changes: 89 additions & 3 deletions tests/test_sft_trainer.py
@@ -94,6 +94,17 @@ def setUp(self):
],
}
)
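# Pretokenized counterpart of the dummy dataset: ten copies of one encoded
# sentence stored under "input_ids", with no raw-text column to tokenize.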
self.dummy_tokenized_dataset = Dataset.from_dict(
{
"input_ids": [
self.tokenizer.encode(
"TRL is a library to post-train LLMs and diffusion models with methods such as Supervised Fine-tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO)."
)
]
* 10
}
)

self.conversational_lm_dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling")
self.standard_prompt_completion_dataset = load_dataset(
"trl-internal-testing/zen", "standard_prompt_completion"
@@ -106,15 +117,23 @@ def setUp(self):
[
{
"role": "user",
"content": [{"type": "text", "text": "What is in this image?"}, {"type": "image"}],
"content": [
{"type": "text", "text": "What is in this image?"},
{"type": "image"},
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "It is random noise."}],
},
{
"role": "user",
"content": [{"type": "text", "text": "Oh ye, you are right, what is 1+1"}],
"content": [
{
"type": "text",
"text": "Oh ye, you are right, what is 1+1",
}
],
},
{
"role": "assistant",
@@ -124,7 +143,10 @@ def setUp(self):
[
{
"role": "user",
"content": [{"type": "text", "text": "What is in this image?"}, {"type": "image"}],
"content": [
{"type": "text", "text": "What is in this image?"},
{"type": "image"},
],
},
{
"role": "assistant",
@@ -159,6 +181,42 @@ def setUp(self):
num_of_sequences=16,
)

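# Packed train/eval datasets built from the pretokenized examples;
# dataset_text_field=None marks the inputs as already tokenized.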
self.train_dataset_from_pretokenized = ConstantLengthDataset(
self.tokenizer,
self.dummy_tokenized_dataset,
dataset_text_field=None,
seq_length=16,
num_of_sequences=16,
)

self.eval_dataset_from_pretokenized = ConstantLengthDataset(
self.tokenizer,
self.dummy_tokenized_dataset,
dataset_text_field=None,
seq_length=16,
num_of_sequences=16,
)

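# The packed stream should yield fixed-length input_ids/labels that decode
# back to text from the source sentence.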
def test_constant_length_dataset_with_pretokenized_data(self):
constant_len_dataset = ConstantLengthDataset(
self.tokenizer,
self.dummy_tokenized_dataset,
dataset_text_field=None,
)

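# The packed dataset is expected to report the length of the wrapped dataset.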
assert len(constant_len_dataset) == len(self.dummy_tokenized_dataset)
assert len(constant_len_dataset) > 0

for example in constant_len_dataset:
assert "input_ids" in example
assert "labels" in example

assert len(example["input_ids"]) == constant_len_dataset.seq_length
assert len(example["labels"]) == constant_len_dataset.seq_length

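# A decoded packed block should contain text from the source sentence.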
decoded_text = self.tokenizer.decode(example["input_ids"])
assert ("TRL" in decoded_text) and ("(DPO)" in decoded_text)

def test_constant_length_dataset(self):
formatted_dataset = ConstantLengthDataset(
self.tokenizer,
@@ -237,6 +295,34 @@ def test_sft_trainer(self):

assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")

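# End-to-end check: SFTTrainer trains and evaluates directly on the
# pre-packed ConstantLengthDataset built in setUp.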
def test_sft_trainer_with_pretokenized_data_packing(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
output_dir=tmp_dir,
dataloader_drop_last=True,
eval_strategy="steps",
max_steps=4,
eval_steps=2,
save_steps=2,
per_device_train_batch_size=2,
packing=True,
report_to="none",
)

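# The trainer presumably consumes the ConstantLengthDataset as-is,
# without re-tokenizing or re-packing it.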
trainer = SFTTrainer(
model=self.model_id,
args=training_args,
train_dataset=self.train_dataset_from_pretokenized,
eval_dataset=self.eval_dataset_from_pretokenized,
)

trainer.train()

assert trainer.state.log_history[-1]["train_loss"] is not None
assert trainer.state.log_history[0]["eval_loss"] is not None

assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")

def test_sft_trainer_uncorrect_data(self):
with tempfile.TemporaryDirectory() as tmp_dir:
# Should work as SFTTrainer natively supports conversational LM datasets
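The added tests can be run in isolation with, for example: pytest tests/test_sft_trainer.py -k pretokenized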
