Padding free dpo #2437

Status: Closed. dame-cell wants to merge 58 commits (changes shown from 19 commits).

Commits
ca99954
added eos token for ppotrainer
dame-cell Nov 30, 2024
fe1d5f6
remove the unnecessary stuff
dame-cell Nov 30, 2024
b15c635
Update ppo_config.py
dame-cell Nov 30, 2024
1bcb3a4
remove redundant EOS token fallback
dame-cell Dec 1, 2024
2ef2b24
remove redundant EOS token fallback
dame-cell Dec 1, 2024
42a0f73
remove some unnecessary tests stuff
dame-cell Dec 1, 2024
6130a91
added tests and update concatenated_inputs
dame-cell Dec 4, 2024
6732ed2
return only list and also a lot to do
dame-cell Dec 4, 2024
ce67292
padding free not tested but getting closer
dame-cell Dec 10, 2024
91e40aa
rebase and also reevaluate my approach
dame-cell Dec 10, 2024
8a34cb5
merge main
dame-cell Dec 10, 2024
1d38632
fix indentation
dame-cell Dec 10, 2024
814d69e
better tests
dame-cell Dec 10, 2024
7dae607
concatenated_forward now supports padding_free
dame-cell Dec 10, 2024
1a59d74
collator now does not return attention masks
dame-cell Dec 10, 2024
562c52e
position ids and no attention mask works
dame-cell Dec 11, 2024
d194054
update concatenated forward to support padding_free
dame-cell Dec 11, 2024
3855851
update concatenated forward to support padding_free
dame-cell Dec 11, 2024
d9adbfb
Merge branch 'main' into padding_free_dpo
dame-cell Dec 11, 2024
1145006
grad accumulation tests
dame-cell Dec 13, 2024
24f73a4
Merge branch 'padding_free_dpo' of https://github.com/dame-cell/trl i…
dame-cell Dec 13, 2024
f6bd9e1
Resolved merge conflict in ppo_trainer.py
dame-cell Dec 13, 2024
bbd99cf
Resolved merge conflict in ppo_trainer.py
dame-cell Dec 13, 2024
ba4969d
Resolved merge conflict in ppo_trainer.py
dame-cell Dec 13, 2024
187b1e5
Resolved merge conflict in ppo_trainer.py
dame-cell Dec 13, 2024
1d9ce3e
fix indentation
dame-cell Dec 13, 2024
8a974cc
comments update
dame-cell Dec 13, 2024
7f0298b
fix some small issue
dame-cell Dec 13, 2024
2900275
fix some small issue
dame-cell Dec 13, 2024
58e779a
fix some small issue
dame-cell Dec 13, 2024
6a1e251
fix some small issue
dame-cell Dec 13, 2024
f92e056
update concatenated_forward to support padding_free
dame-cell Dec 13, 2024
457e3a1
fix some small issue
dame-cell Dec 13, 2024
f1789f4
fix some small issue
dame-cell Dec 13, 2024
0103728
So we need to make sure to correctly handle the list
dame-cell Dec 13, 2024
5837faa
by correctly updating concatenated_forward it works now
dame-cell Dec 13, 2024
0321c1d
refactoring concatenated_forward and batched same length seq for padd…
dame-cell Dec 14, 2024
a6e2163
update
dame-cell Dec 14, 2024
1328fc3
padding_free in concatenated_forward and update_test
dame-cell Dec 17, 2024
51a2cc6
Merge branch 'main' into padding_free_dpo
dame-cell Dec 17, 2024
986ed71
padding_free in concatenated_forward and update_test
dame-cell Dec 17, 2024
570b79a
padding_free in concatenated_forward and update_test
dame-cell Dec 17, 2024
9dd9564
Merge branch 'padding_free_dpo' of https://github.com/dame-cell/trl i…
dame-cell Dec 17, 2024
b781876
Merge branch 'main' into padding_free_dpo
dame-cell Dec 18, 2024
525ecb2
Merge branch 'main' into padding_free_dpo
dame-cell Dec 19, 2024
c8ce9c8
Merge branch 'main' into padding_free_dpo
dame-cell Dec 19, 2024
b7fad73
Reverted PPO trainer to original version and updated DPO files
dame-cell Dec 19, 2024
55cd219
Merge branch 'padding_free_dpo' of https://github.com/dame-cell/trl i…
dame-cell Dec 19, 2024
5e8df69
Updated DPO files
dame-cell Dec 19, 2024
ba1ded1
Merge branch 'main' into padding_free_dpo
dame-cell Dec 20, 2024
0784202
Merge branch 'main' into padding_free_dpo
dame-cell Dec 21, 2024
ddfed7c
update test_dpo_trainer.py
dame-cell Dec 21, 2024
955c7e8
update dpo_trainer.py
dame-cell Dec 21, 2024
ba4356e
update dpo_trainer.py
dame-cell Dec 21, 2024
c61abb5
update dpo_trainer.py
dame-cell Dec 21, 2024
64e9909
update dpo_trainer.py
dame-cell Dec 22, 2024
68186e7
Merge branch 'main' into padding_free_dpo
dame-cell Dec 22, 2024
90f50e4
Merge branch 'main' into padding_free_dpo
dame-cell Jan 2, 2025
2 changes: 1 addition & 1 deletion examples/scripts/ppo/ppo.py
@@ -169,4 +169,4 @@ def tokenize(element):
     if training_args.push_to_hub:
         trainer.push_to_hub(dataset_name=script_args.dataset_name)

-    trainer.generate_completions()
+    trainer.generate_completions()

(The removed and added lines are textually identical; the change most likely adds a trailing newline at end of file, matching "1 addition & 1 deletion".)
88 changes: 40 additions & 48 deletions tests/test_dpo_trainer.py
@@ -436,6 +436,8 @@ def test_dpo_trainer_without_providing_ref_model_with_lora(self):
                 if param.sum() != 0:
                     self.assertFalse(torch.equal(param, new_param))

+
+
     def test_dpo_trainer_padding_token_is_none(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             training_args = DPOConfig(
@@ -471,8 +473,10 @@ def test_dpo_trainer_padding_token_is_none(self):
             )

             trainer.train()

+
+
-    def test_dpo_trainer_w_dataset_num_proc(self):
+    def test_dpo_trainer_padding_free(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             training_args = DPOConfig(
                 output_dir=tmp_dir,
@@ -483,74 +487,62 @@ def test_dpo_trainer_w_dataset_num_proc(self):
                 learning_rate=9e-1,
                 eval_strategy="steps",
                 beta=0.1,
-                dataset_num_proc=5,
                 report_to="none",
             )

             dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

-            tokenizer = AutoTokenizer.from_pretrained(self.model_id)
-            tokenizer.pad_token = None
-
-            with self.assertRaisesRegex(
-                ValueError,
-                expected_regex=r"Can't find `pad_token_id` in the `processing_class`. "
-                r"Explicitly set `tokenizer.pad_token` \(e.g. `tokenizer.pad_token = tokenizer.eos_token`\) "
-                r"before instantiating the trainer.",
-            ):
-                trainer = DPOTrainer(
-                    model=self.model,
-                    ref_model=None,
-                    args=training_args,
-                    processing_class=tokenizer,
-                    train_dataset=dummy_dataset["train"],
-                    eval_dataset=dummy_dataset["test"],
-                )
-
-            trainer.train()
+            # Test with padding_free=True
+            trainer_padding_free = DPOTrainer(
+                model=self.model,
+                ref_model=None,
+                args=training_args,
+                tokenizer=self.tokenizer,
+                padding_free=True,
+                train_dataset=dummy_dataset["train"],
+                eval_dataset=dummy_dataset["test"],
+            )

-    def test_tr_dpo_trainer(self):
+            batch_paddingfree = next(iter(trainer_padding_free.get_train_dataloader()))
+            # Check for correct keys in padding-free format
+            self.assertIn("prompt_input_ids", batch_paddingfree)
+            self.assertIn("chosen_input_ids", batch_paddingfree)
+            self.assertIn("rejected_input_ids", batch_paddingfree)
+            self.assertIn("prompt_position_ids", batch_paddingfree)
+            self.assertIn("chosen_position_ids", batch_paddingfree)
+            self.assertIn("rejected_position_ids", batch_paddingfree)
+            # Attention masks should not be present in padding-free mode
+            self.assertNotIn("prompt_attention_mask", batch_paddingfree)
+            self.assertNotIn("chosen_attention_mask", batch_paddingfree)
+            self.assertNotIn("rejected_attention_mask", batch_paddingfree)

+    def test_dpo_trainer_padding_free_training(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             training_args = DPOConfig(
                 output_dir=tmp_dir,
                 per_device_train_batch_size=2,
                 max_steps=3,
                 remove_unused_columns=False,
                gradient_accumulation_steps=4,

Inline review thread on the gradient_accumulation_steps line above (the current diff reads 1, below):

Contributor:
    Would be good to have tests for this with gradient accumulation too, perhaps using pytest.mark.parametrize?

dame-cell (Author):
    All right, will do so. Thanks for reviewing 😎
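A sketch of the reviewer's suggestion, written as a standalone pytest function because pytest.mark.parametrize does not inject arguments into plain unittest methods (inside a TestCase, parameterized.expand is the usual route). The tiny model id is a placeholder, and the `tokenizer=` / `padding_free=` keyword arguments mirror this branch's test code rather than any released trl API:

import pytest
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

# placeholder id; any small causal LM would do for a smoke test
MODEL_ID = "trl-internal-testing/tiny-random-LlamaForCausalLM"

@pytest.mark.parametrize("grad_accum_steps", [1, 2, 4])
def test_dpo_padding_free_grad_accum(grad_accum_steps, tmp_path):
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

    training_args = DPOConfig(
        output_dir=str(tmp_path),
        per_device_train_batch_size=2,
        max_steps=3,
        gradient_accumulation_steps=grad_accum_steps,  # the parametrized knob
        beta=0.1,
        report_to="none",
    )
    trainer = DPOTrainer(
        model=model,
        ref_model=None,
        args=training_args,
        tokenizer=tokenizer,  # this branch's tests pass `tokenizer=`
        padding_free=True,    # the flag under review (assumed API)
        train_dataset=dummy_dataset["train"],
    )
    trainer.train()
    # training should log a loss for every accumulation setting
    assert trainer.state.log_history[-1]["train_loss"] is not None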

+                gradient_accumulation_steps=1,
                 learning_rate=9e-1,
                 eval_strategy="steps",
-                precompute_ref_log_probs=False,
-                sync_ref_model=True,
-                ref_model_mixup_alpha=0.5,
-                ref_model_sync_steps=1,
                 beta=0.1,
                 report_to="none",
             )

-            dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
-
+            dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
             trainer = DPOTrainer(
-                model=self.model,
-                ref_model=self.ref_model,
-                args=training_args,
-                processing_class=self.tokenizer,
-                train_dataset=dummy_dataset["train"],
-                eval_dataset=dummy_dataset["test"],
-            )
+                model=self.model,
+                ref_model=None,
+                args=training_args,
+                tokenizer=self.tokenizer,
+                padding_free=True,
+                train_dataset=dummy_dataset["train"],
+                eval_dataset=dummy_dataset["test"],
+            )

-            # params of the ref model as its the same as the model
             previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

             trainer.train()

+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

             # check the params have changed
             for n, param in previous_trainable_params.items():
                 new_param = trainer.ref_model.get_parameter(n)
                 # check the ref model's params have changed - ignore 0 biases
                 if param.sum() != 0:
                     self.assertFalse(torch.equal(param, new_param))
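The two padding-free tests above exercise the idea named in the commit trail ("collator now does not return attention masks", "position ids and no attention mask works"): pack every sequence into a single row and let restarting position ids mark the boundaries. A minimal sketch of that forward path, assuming an attention implementation (for example flash attention 2) that can recover boundaries from position ids; this illustrates the technique and is not the branch's actual concatenated_forward:

import torch

def padding_free_forward(model, sequences):
    # sequences: list of 1-D LongTensors, one packed prompt+completion each
    input_ids = torch.cat(sequences).unsqueeze(0)  # shape (1, total_tokens)
    position_ids = torch.cat(
        [torch.arange(len(seq)) for seq in sequences]
    ).unsqueeze(0)  # restarts at 0 at every sequence boundary

    # no attention_mask: boundaries are implied by the position ids
    outputs = model(input_ids=input_ids, position_ids=position_ids)

    # split the packed logits back into per-sequence chunks for the loss
    lengths = [len(seq) for seq in sequences]
    return torch.split(outputs.logits.squeeze(0), lengths, dim=0)

The payoff is that no compute is spent on pad tokens, which matters for DPO, where chosen and rejected completions can differ widely in length.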

     @require_no_wandb
     def test_dpo_trainer_generate_during_eval_no_wandb(self):
         with tempfile.TemporaryDirectory() as tmp_dir: