add arg padding_free to DataCollatorForCompletionOnlyLM #1887

Merged · 12 commits · Aug 26, 2024
tests/test_data_collator_completion_only.py (+44 -0)
@@ -104,3 +104,47 @@ def test_data_collator_handling_of_long_sequences(self):
        encoded_instance = self.collator.torch_call([self.tokenized_instruction])
        result = torch.all(encoded_instance["labels"] == -100)
        assert result, "Not all values in the tensor are -100."

    def test_padding_free(self):
        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.pad_token_id = tokenizer.eos_token_id
        inst1 = "### System: You are a helpful assistant.\n\n### User: How much is 2+2?\n\n### Assistant: 2+2 equals 4"
        inst2 = "### System: You are an honest and helpful assistant.\n\n### User: What is the answer of 22x22?\n\n### Assistant: 22x22 equals 484"

        response_template = "\n### Assistant:"
        collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)
        collator_paddingfree = DataCollatorForCompletionOnlyLM(
            response_template, tokenizer=tokenizer, padding_free=True
        )

        tokenized_instruction = [tokenizer(x, add_special_tokens=False) for x in [inst1, inst2]]
        batch = collator(tokenized_instruction)
        batch_paddingfree = collator_paddingfree(tokenized_instruction)

        self.assertNotIn("attention_mask", batch_paddingfree)
        self.assertIn("input_ids", batch_paddingfree)
        self.assertIn("labels", batch_paddingfree)
        self.assertIn("position_ids", batch_paddingfree)
        self.assertEqual(
            batch_paddingfree["input_ids"].size(),
            batch_paddingfree["labels"].size()
        )
        self.assertEqual(
            batch_paddingfree["labels"].size(),
            batch_paddingfree["position_ids"].size()
        )

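        # Build the expected padding-free tensors from the padded batch: strip pad tokens,
        # derive position_ids from the attention-mask cumsum, and mask each example's first label.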
        attn_mask = batch["attention_mask"]
        input_ids_remove_pad = batch["input_ids"][attn_mask.bool()].unsqueeze(0)
        expected_position_ids = attn_mask.cumsum(1)[attn_mask.bool()].unsqueeze(0) - 1
        expected_labels = []
        for idx in range(batch["input_ids"].size(0)):
            expected_labels.append(batch["labels"][idx][attn_mask[idx].bool()])
            expected_labels[-1][0] = collator.ignore_index
        expected_labels = torch.cat(expected_labels).unsqueeze(0)

        self.assertTrue((input_ids_remove_pad == batch_paddingfree["input_ids"]).all())
        self.assertTrue((expected_position_ids == batch_paddingfree["position_ids"]).all())
        self.assertTrue((expected_labels == batch_paddingfree["labels"]).all())
trl/trainer/utils.py (+10 -0)
@@ -98,6 +98,7 @@ def __init__(
        *args,
        mlm: bool = False,
        ignore_index: int = -100,
        padding_free: bool = False,
        **kwargs,
    ):
        super().__init__(*args, mlm=mlm, **kwargs)
@@ -127,6 +128,7 @@ def __init__(
            )

        self.ignore_index = ignore_index
        self.padding_free = padding_free

    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        batch = super().torch_call(examples)
@@ -211,6 +213,14 @@ def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
                if len(response_token_ids_idxs) < len(human_token_ids_idxs):
                    batch["labels"][i, human_token_ids_idxs[-1] :] = self.ignore_index

        if self.padding_free:
            # remove padding, `attention_mask` and add `position_ids`
            attn_mask = batch.pop("attention_mask")
            batch["input_ids"] = batch["input_ids"][attn_mask.bool()].unsqueeze(0)
            batch["position_ids"] = attn_mask.cumsum(1)[attn_mask.bool()].unsqueeze(0) - 1
            batch["labels"] = batch["labels"][attn_mask.bool()].unsqueeze(0)
batch["labels"][batch["position_ids"] == 0] = self.ignore_index

        return batch


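For reference, a minimal usage sketch of the new flag (the prompt strings and the final print are illustrative; the constructor arguments mirror the test above). With `padding_free=True` the collator concatenates the batch into a single packed row and returns `position_ids` instead of `attention_mask`, which is typically paired with attention implementations that can infer sequence boundaries from the position ids.

```python
from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

collator = DataCollatorForCompletionOnlyLM(
    "\n### Assistant:",   # loss is only computed on tokens after this template
    tokenizer=tokenizer,
    padding_free=True,    # the new flag: pack examples instead of padding them
)

# Illustrative prompts, not taken from this PR
examples = [
    tokenizer("### User: How much is 2+2?\n\n### Assistant: 2+2 equals 4", add_special_tokens=False),
    tokenizer("### User: What is 3+3?\n\n### Assistant: 3+3 equals 6", add_special_tokens=False),
]

batch = collator(examples)
# "input_ids", "labels" and "position_ids" all have shape (1, total_tokens);
# "attention_mask" is dropped, and the label at each position_id == 0 is set to ignore_index.
print(batch["position_ids"])
```

The test added in this PR checks exactly this contract: the packed `input_ids` and `labels` equal the padded batch with pad positions dropped, and `position_ids` restart at 0 for each example.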