Skip to content

Commit

Permalink
add prompt to test dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
ArthurCamara committed Sep 26, 2024
1 parent 781f1d7 commit 72713b2
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions sentence_transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,9 @@ def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader
dataloader_params["batch_sampler"] = batch_sampler

elif isinstance(eval_dataset, Dataset):
if self.prompt is not None:
eval_dataset = self.add_prompts_to_dataset(eval_dataset)

batch_sampler = self.get_batch_sampler(
eval_dataset,
batch_size=self.args.eval_batch_size,
Expand All @@ -789,8 +792,6 @@ def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader
generator=generator,
)
dataloader_params["batch_sampler"] = batch_sampler
if self.prompt is not None:
eval_dataset = self.add_prompts_to_dataset(eval_dataset, self.prompt)

else:
raise ValueError(
Expand Down Expand Up @@ -848,6 +849,8 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
raise ValueError(
"Sentence Transformers is not compatible with a DatasetDict containing an IterableDataset."
)
if self.prompt is not None:
test_dataset[dataset_name] = self.add_prompts_to_dataset(dataset, dataset_name)
if isinstance(self.loss, dict):
test_dataset = self.add_dataset_name_column(test_dataset)
batch_samplers = [
Expand All @@ -873,6 +876,8 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
elif isinstance(test_dataset, Dataset):
self.validate_column_names(test_dataset)

if self.prompt is not None:
test_dataset = self.add_prompts_to_dataset(test_dataset)
batch_sampler = self.get_batch_sampler(
test_dataset,
batch_size=self.args.eval_batch_size,
Expand Down

0 comments on commit 72713b2

Please sign in to comment.