Commit
🧑‍🍳 Add precompute batch size argument in DPOTrainer for reference model (huggingface#2426)

* added precompute_batch

* review-fixes

* moving up

* Update trl/trainer/dpo_config.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/dpo_config.py

* Update trl/trainer/dpo_config.py [ci skip]

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
SwayamInSync and qgallouedec authored Dec 2, 2024
1 parent 148b592 commit f6f4265
Showing 3 changed files with 43 additions and 2 deletions.
tests/test_dpo_trainer.py (34 additions, 0 deletions)

@@ -350,6 +350,40 @@ def test_dpo_trainer_with_ref_model_is_model(self):
                     train_dataset=dummy_dataset["train"],
                 )
 
+    def test_precompute_ref_batch_size(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = DPOConfig(
+                output_dir=tmp_dir,
+                per_device_train_batch_size=2,
+                precompute_ref_log_probs=True,
+                precompute_ref_batch_size=4,
+                report_to="none",
+            )
+
+            dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
+
+            trainer = DPOTrainer(
+                model=self.model,
+                ref_model=self.ref_model,
+                args=training_args,
+                processing_class=self.tokenizer,
+                train_dataset=dummy_dataset["train"],
+                eval_dataset=dummy_dataset["test"],
+            )
+
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            trainer.train()
+
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # check the params have changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                # check the params have changed - ignore 0 biases
+                if param.sum() != 0:
+                    self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
+
     @require_peft
     def test_dpo_trainer_without_providing_ref_model_with_lora(self):
         from peft import LoraConfig
trl/trainer/dpo_config.py (5 additions, 0 deletions)

@@ -94,6 +94,10 @@ class DPOConfig(TrainingArguments):
         precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
             Whether to precompute reference model log probabilities for training and evaluation datasets. This is
             useful when training without the reference model to reduce the total GPU memory needed.
+        precompute_ref_batch_size (`Optional[int]`, *optional*, defaults to `None`):
+            Batch size to use when precomputing reference model log probabilities. This can be set higher than the
+            training batch size to speed up preprocessing. If `None`, defaults to `per_device_train_batch_size` for
+            training and `per_device_eval_batch_size` for evaluation.
         dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
             Number of processes to use for processing the dataset.
         model_init_kwargs (`Optional[dict[str, Any]]`, *optional*, defaults to `None`):
@@ -173,6 +177,7 @@ class DPOConfig(TrainingArguments):
     disable_dropout: bool = True
     generate_during_eval: bool = False
     precompute_ref_log_probs: bool = False
+    precompute_ref_batch_size: Optional[int] = None
     dataset_num_proc: Optional[int] = None
     model_init_kwargs: Optional[dict[str, Any]] = None
     ref_model_init_kwargs: Optional[dict[str, Any]] = None
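
For context, a minimal sketch of how the new option is meant to be used (the output directory below is a placeholder, not part of this commit):

from trl import DPOConfig

# Run the one-off reference log-prob pass with larger batches than training uses.
# If precompute_ref_batch_size is left as None, the precompute dataloader falls back
# to per_device_train_batch_size (train) / per_device_eval_batch_size (eval).
training_args = DPOConfig(
    output_dir="dpo-output",            # placeholder path
    per_device_train_batch_size=2,
    precompute_ref_log_probs=True,      # precompute_ref_batch_size only takes effect when this is True
    precompute_ref_batch_size=8,
    report_to="none",
)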
trl/trainer/dpo_trainer.py (4 additions, 2 deletions)

@@ -684,8 +684,9 @@ def get_train_dataloader(self) -> DataLoader:
         """
 
         if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs:
+            batch_size = self.args.precompute_ref_batch_size or self.args.per_device_train_batch_size
             dataloader_params = {
-                "batch_size": self.args.per_device_train_batch_size,
+                "batch_size": batch_size,
                 "collate_fn": self.data_collator,
                 "num_workers": self.args.dataloader_num_workers,
                 "pin_memory": self.args.dataloader_pin_memory,
@@ -737,8 +738,9 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
         eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
 
         if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs:
+            batch_size = self.args.precompute_ref_batch_size or self.args.per_device_eval_batch_size
             dataloader_params = {
-                "batch_size": self.args.per_device_eval_batch_size,
+                "batch_size": batch_size,
                 "collate_fn": self.data_collator,
                 "num_workers": self.args.dataloader_num_workers,
                 "pin_memory": self.args.dataloader_pin_memory,
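
Putting the pieces together, a hedged end-to-end sketch modeled on the test above (the model identifier and output directory are placeholders; the dataset is the tiny preference set the test uses):

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model_id = "your-org/your-causal-lm"  # placeholder: any causal LM checkpoint
model = AutoModelForCausalLM.from_pretrained(model_id)
ref_model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

training_args = DPOConfig(
    output_dir="dpo-precompute-demo",   # placeholder
    per_device_train_batch_size=2,
    precompute_ref_log_probs=True,
    precompute_ref_batch_size=8,        # reference pass runs with batches of 8, training stays at 2
    report_to="none",
)

trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
)
trainer.train()  # reference log-probs are precomputed before the first optimization step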
