Skip to content

Commit

Permalink
Use ..._batch_size rather than per_device_..._batch_size
Browse files · Browse the repository at this point in the history
The former is identical for single-GPU and DDP, but has a higher batch size for DP (which is the expected behaviour).
  • Loading branch information
tomaarsen committed Sep 10, 2024
1 parent 6d7637a commit 5f91adb
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions sentence_transformers/trainer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -571,7 +571,7 @@ def get_train_dataloader(self) -> DataLoader:
batch_samplers = [
self.get_batch_sampler(
dataset,
batch_size=self.args.per_device_train_batch_size,
batch_size=self.args.train_batch_size,
drop_last=self.args.dataloader_drop_last,
valid_label_columns=data_collator.valid_label_columns,
generator=generator,
Expand Down Expand Up @@ -644,7 +644,7 @@ def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader
batch_samplers = [
self.get_batch_sampler(
dataset,
batch_size=self.args.per_device_eval_batch_size,
batch_size=self.args.eval_batch_size,
drop_last=self.args.dataloader_drop_last,
valid_label_columns=data_collator.valid_label_columns,
generator=generator,
Expand Down Expand Up @@ -708,7 +708,7 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
batch_samplers = [
self.get_batch_sampler(
dataset,
batch_size=self.args.per_device_eval_batch_size,
batch_size=self.args.eval_batch_size,
drop_last=self.args.dataloader_drop_last,
valid_label_columns=data_collator.valid_label_columns,
generator=generator,
Expand Down

0 comments on commit 5f91adb

Please sign in to comment.