-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
π§βπ³ Add precompute batch size argument in DPOTrainer
for reference model
#2426
Conversation
also if you can kindly add this config to the tests where we test with |
Good point @kashif. We cannot really test if it works, but at least, we can check that it doesn't fail when this arg is passed. def test_precompute_ref_batch_size(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = DPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
precompute_ref_log_probs=True,
precompute_ref_batch_size=4,
report_to="none",
)
dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
trainer = DPOTrainer(
model=self.model,
ref_model=self.ref_model,
args=training_args,
processing_class=self.tokenizer,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
)
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) |
DPOTrainer
for reference model
Co-authored-by: Quentin GallouΓ©dec <45557362+qgallouedec@users.noreply.github.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, I'll merge once the CI is green. Thanks again @SwayamInSync
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
What does this PR do?
Fixes #2421
Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines.
Who can review?
@qgallouedec