Padding free dpo #2437
base: main
Conversation
Not really done yet, but for now here are some tasks still to be done:

Most of the stuff is done; just some small things left, like dealing with lists and converting them to tensors.
Hey @osanseviero, the main idea for using padding_free is mostly in place now, but there are still a few things that need to be done. It would be awesome if you could take a look at the code and let me know if there's anything else I should address or add. I've made it so the user can directly do this:

```python
trainer = DPOTrainer(
    model=self.model,
    ref_model=None,
    args=training_args,
    tokenizer=self.tokenizer,
    padding_free=True,  # when True, no padding is used
    train_dataset=dummy_dataset["train"],
    eval_dataset=dummy_dataset["test"],
)
```
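For context, here is a minimal sketch of what a padding-free collator could do: instead of padding every example to the longest sequence, it concatenates the examples and tracks per-sequence `position_ids` so the attention implementation can tell the sequences apart. The function name and shapes below are illustrative assumptions, not the PR's actual implementation.

```python
import torch

def naive_padding_free_collate(features):
    """Illustrative only: flatten a batch instead of padding it.

    Real padding-free batching (e.g. with FlashAttention-2) also needs
    sequence-boundary information; here that is represented by
    position_ids restarting at 0 for every example.
    """
    input_ids, position_ids = [], []
    for f in features:
        ids = f["input_ids"]
        input_ids.extend(ids)
        position_ids.extend(range(len(ids)))  # positions restart for each sequence
    return {
        "input_ids": torch.tensor([input_ids]),       # shape (1, total_tokens)
        "position_ids": torch.tensor([position_ids]),
    }
```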
tests/test_dpo_trainer.py
Outdated
```python
with tempfile.TemporaryDirectory() as tmp_dir:
    training_args = DPOConfig(
        output_dir=tmp_dir,
        per_device_train_batch_size=2,
        max_steps=3,
        remove_unused_columns=False,
        gradient_accumulation_steps=4,
```
Would be good to have tests for this with gradient accumulation too, perhaps using pytest.mark.parametrize?
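As a rough sketch of what such a parametrized test could look like (the `padding_free` flag is the one proposed in this PR; the `model`, `tokenizer`, and `dummy_dataset` fixtures and the test body are assumptions, not the PR's actual test):

```python
import tempfile

import pytest
from trl import DPOConfig, DPOTrainer


# `model`, `tokenizer` and `dummy_dataset` are assumed to come from fixtures
# defined elsewhere in the test module; they are placeholders here.
@pytest.mark.parametrize("gradient_accumulation_steps", [1, 4])
def test_dpo_trainer_padding_free(model, tokenizer, dummy_dataset, gradient_accumulation_steps):
    with tempfile.TemporaryDirectory() as tmp_dir:
        training_args = DPOConfig(
            output_dir=tmp_dir,
            per_device_train_batch_size=2,
            max_steps=3,
            remove_unused_columns=False,
            gradient_accumulation_steps=gradient_accumulation_steps,
            report_to="none",
        )
        trainer = DPOTrainer(
            model=model,
            ref_model=None,
            args=training_args,
            tokenizer=tokenizer,
            padding_free=True,  # flag proposed in this PR
            train_dataset=dummy_dataset["train"],
            eval_dataset=dummy_dataset["test"],
        )
        trainer.train()
```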
All right, will do so. Thanks for reviewing 😎
trl/trainer/ppo_config.py
Outdated
```diff
@@ -53,6 +53,9 @@ class PPOConfig(OnPolicyConfig):
        Discount factor.
```
This modification shouldn't be here, right?
Oh, my bad, I'll fix it right now.
You still have modifications in the PPO files.
You can push it, no worries, we can still refine it after.
Thank you for your understanding! I wanted to let you know that I'm a bit tied up today and tomorrow, so I might not be able to push the code right away. I'll try to get to it as soon as possible, but please feel free to let me know if there's a hard deadline I should prioritize. Thanks for your patience!
No rush on our side :)
All right, so I think this does it. I did check whether we can train this on a single T4 GPU in a Colab notebook:

```shell
python trl/examples/scripts/dpo.py \
    --dataset_name trl-lib/ultrafeedback_binarized \
    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
    --learning_rate 5.0e-6 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --gradient_checkpointing \
    --logging_steps 1 \
    --output_dir Qwen2-0.5B-DPO \
    --no_remove_unused_columns \
    --use_peft \
    --lora_r 32 \
    --lora_alpha 16
```

Without padding_free it kept saying OOM. Is this normal, or what?
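For intuition on why the padded run can OOM while the padding-free run fits (a toy illustration with made-up sequence lengths, not measurements from this run): activation memory scales roughly with the number of tokens the model actually processes, and padding stores every sequence in the batch at the batch's maximum length.

```python
# Toy numbers, purely illustrative: how many tokens a batch holds
# with padding vs. when sequences are concatenated (padding-free).
seq_lens = [812, 96, 140, 2048]                 # hypothetical prompt+completion lengths
padded_tokens = len(seq_lens) * max(seq_lens)   # 4 * 2048 = 8192 tokens
packed_tokens = sum(seq_lens)                   # 3096 tokens
print(padded_tokens, packed_tokens)             # activation memory scales roughly with these
```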
@osanseviero Just wanted to follow up on this PR and see if there's any feedback so far. I'm happy to clarify anything or make updates if needed. Let me know whenever you get a chance. Thanks so much for your time! 🙌
You still need to revert the changes applied to the PPO files, and apply pre-commit.
@qgallouedec
Here are the last steps' metrics from training.
There still appear to be noticeable differences between the Rewards/Chosen and Rewards/Rejected metrics. Despite my efforts to resolve this, I just could not fix it.
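For reference when reading those metrics, DPO's logged chosen/rejected rewards are typically the β-scaled log-probability ratios between the policy and the frozen reference model. The snippet below is a sketch of that computation, not the PR's code; the tensor names are assumptions.

```python
# Sketch: how DPO "rewards/*" metrics are typically derived.
# Inputs are torch tensors of per-example summed log-probs of the chosen /
# rejected completions under the policy and the frozen reference model.
def dpo_reward_metrics(policy_chosen_logps, policy_rejected_logps,
                       ref_chosen_logps, ref_rejected_logps, beta=0.1):
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
    margins = chosen_rewards - rejected_rewards                      # rewards/margins
    accuracy = (chosen_rewards > rejected_rewards).float().mean()    # rewards/accuracies
    return chosen_rewards.mean(), rejected_rewards.mean(), margins.mean(), accuracy
```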
(Please stop tagging osanseviero unless you have a good reason. He is not involved here; please don't bother him 🙏)
What does this PR do?
New feature #2422
For now this is just a draft; I will continue to work on it.
@osanseviero