Skip to content

Commit

Permalink
Simplify logic and add prompt to eval dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
ArthurCamara committed Sep 26, 2024
1 parent c49ca90 commit 781f1d7
Showing 1 changed file with 23 additions and 9 deletions.
32 changes: 23 additions & 9 deletions sentence_transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,20 @@ def add_dataset_name_column(self, dataset_dict: DatasetDict) -> DatasetDict:
dataset_dict[key] = dataset.add_column("dataset_name", [key] * len(dataset))
return dataset_dict

def add_prompts_to_dataset(self, dataset: Dataset, prompt: str | dict[str, str]) -> Dataset:
def add_prompts_to_dataset(self, dataset: Dataset, dataset_name: str | None = None) -> Dataset:
if dataset_name is not None:
if isinstance(self.prompt, dict):
if dataset_name not in self.prompt:
raise ValueError(f"dataset_name {dataset_name} not found in self.prompt")
prompt = self.prompt[dataset_name]
else:
# If self.prompt is dict[str, dict[str, str]], raise an error.
if isinstance(self.prompt, dict) and not isinstance(list(self.prompt.keys())[0], str):
raise ValueError(
"When defining prompts and a single dataset, use either a single string for applying the prompt to all columns or a dict mapping column names to prompts."
)
prompt = self.prompt

def _add_prompts(sample):
if isinstance(prompt, dict):
for col in prompt.keys():
Expand Down Expand Up @@ -642,8 +655,9 @@ def get_train_dataloader(self) -> DataLoader:
raise ValueError(
"Sentence Transformers is not compatible with a DatasetDict containing an IterableDataset."
)
if isinstance(self.prompt, dict) and dataset_name in self.prompt:
train_dataset[dataset_name] = self.add_prompts_to_dataset(dataset, self.prompt[dataset_name])
if self.prompt is not None:
train_dataset[dataset_name] = self.add_prompts_to_dataset(dataset, dataset_name)

if isinstance(self.loss, dict):
train_dataset = self.add_dataset_name_column(train_dataset)

Expand Down Expand Up @@ -671,11 +685,7 @@ def get_train_dataloader(self) -> DataLoader:
self.validate_column_names(train_dataset)

if self.prompt is not None:
if isinstance(self.prompt, dict) and isinstance(list(self.prompt.keys())[0], dict):
raise ValueError(
"When defining prompts and a single dataset, use either a single string for applying the prompt to all columns or a dict mapping column names to prompts."
)
train_dataset = self.add_prompts_to_dataset(train_dataset, self.prompt)
train_dataset = self.add_prompts_to_dataset(train_dataset)

batch_sampler = self.get_batch_sampler(
train_dataset,
Expand Down Expand Up @@ -741,11 +751,13 @@ def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader
)

elif isinstance(eval_dataset, DatasetDict):
for dataset in eval_dataset.values():
for dataset_name, dataset in eval_dataset.items():
if isinstance(dataset, IterableDataset):
raise ValueError(
"Sentence Transformers is not compatible with a DatasetDict containing an IterableDataset."
)
if self.prompt is not None:
eval_dataset[dataset_name] = self.add_prompts_to_dataset(dataset)
if isinstance(self.loss, dict):
eval_dataset = self.add_dataset_name_column(eval_dataset)
batch_samplers = [
Expand Down Expand Up @@ -777,6 +789,8 @@ def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader
generator=generator,
)
dataloader_params["batch_sampler"] = batch_sampler
if self.prompt is not None:
eval_dataset = self.add_prompts_to_dataset(eval_dataset, self.prompt)

else:
raise ValueError(
Expand Down

0 comments on commit 781f1d7

Please sign in to comment.