[feat] Trainer with prompts and prompt masking #2964
Conversation
Hello! Thanks for this PR. I rebased it to get rid of the leftover commits that aren't necessary here.

Could we perhaps add the prompts (and prompt lengths) in the data collator? E.g. right here: https://github.com/ArthurCamara/sentence-transformers/blob/bf9eb803ce2dda26a8ef903c33d80cd1fcb55a3d/sentence_transformers/data_collator.py#L50-L56

The data collator knows the dataset name and the column name (see the snippet), and should then be able to use that information to prepend the prompts "on the fly". In a perfect world we could even tokenize the prompts only once, but that gets complicated with padding and truncation, so it's better to keep it simpler. I also like your idea here; I'm curious to hear your thoughts on this.
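For illustration, a rough sketch of what prompt prepending inside a collator could look like. This is not the PR's actual data collator: the `PromptedDataCollator` name and the `tokenize_fn` callable are made up, and `prompts` is assumed to map dataset name to column name to prompt.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List


@dataclass
class PromptedDataCollator:
    """Tokenizes each text column, prepending a per-dataset, per-column prompt first."""

    tokenize_fn: Callable[[List[str]], Dict[str, Any]]
    prompts: Dict[str, Dict[str, str]] = field(default_factory=dict)

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        dataset_name = features[0].get("dataset_name")
        columns = [key for key in features[0] if key not in ("dataset_name", "label")]
        batch: Dict[str, Any] = {}
        for column in columns:
            texts = [row[column] for row in features]
            prompt = self.prompts.get(dataset_name, {}).get(column, "")
            if prompt:
                # Prepend the prompt on the fly, so the stored dataset stays unchanged.
                texts = [prompt + text for text in texts]
            tokenized = self.tokenize_fn(texts)
            batch.update({f"{column}_{key}": value for key, value in tokenized.items()})
        return batch
```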
This was one of the things I was considering: changing the collator instead of the dataset itself. But I had issues with Accelerator and DDP before when the data was not exclusively tensors (i.e., strings), though I think we can work around it within the collator. I will give it a shot and let you know.
Agreed. =)
Hi, thanks for implementing this. Is there any guide on how to fine-tune with prompts?
Hello! Until this is integrated, I would recommend manually adding the prompts to your training datasets. E.g.:

```python
from typing import Any, Dict, List

from datasets import load_dataset


def prepend_prompt(batch: Dict[str, List[Any]], prompts: Dict[str, str] | None = None) -> Dict[str, List[Any]]:
    if not prompts:
        return batch
    for column_name, prompt in prompts.items():
        batch[column_name] = [prompt + value for value in batch[column_name]]
    return batch


train_dataset = load_dataset("sentence-transformers/natural-questions", split="train")
train_dataset = train_dataset.map(
    prepend_prompt,
    batched=True,
    fn_kwargs={"prompts": {"query": "Represent this sentence for searching relevant passages: "}},
)
print(train_dataset[0])
# {'query': 'Represent this sentence for searching relevant passages: when did richmond last play in a preliminary final', 'answer': "Richmond Football Club Richmond began 2017 with 5 straight wins, a feat it had not achieved since 1995. A series of close losses hampered the Tigers throughout the middle of the season, including a 5-point loss to the Western Bulldogs, 2-point loss to Fremantle, and a 3-point loss to the Giants. Richmond ended the season strongly with convincing victories over Fremantle and St Kilda in the final two rounds, elevating the club to 3rd on the ladder. Richmond's first final of the season against the Cats at the MCG attracted a record qualifying final crowd of 95,028; the Tigers won by 51 points. Having advanced to the first preliminary finals for the first time since 2001, Richmond defeated Greater Western Sydney by 36 points in front of a crowd of 94,258 to progress to the Grand Final against Adelaide, their first Grand Final appearance since 1982. The attendance was 100,021, the largest crowd to a grand final since 1986. The Crows led at quarter time and led by as many as 13, but the Tigers took over the game as it progressed and scored seven straight goals at one point. They eventually would win by 48 points – 16.12 (108) to Adelaide's 8.12 (60) – to end their 37-year flag drought.[22] Dustin Martin also became the first player to win a Premiership medal, the Brownlow Medal and the Norm Smith Medal in the same season, while Damien Hardwick was named AFL Coaches Association Coach of the Year. Richmond's jump from 13th to premiers also marked the biggest jump from one AFL season to the next."}
```

And the rest is the same as the normal training: https://sbert.net/docs/sentence_transformer/training_overview.html
Hey, thanks so much for the quick reply. My main concern here is whether pooling is done on just the text (excluding the prompt). I believe in the INSTRUCTOR paper they do not include the embeddings of the prompt during mean pooling. Would this solution take care of that?
Indeed, my solution only works if you're including the prompt in the pooling. If you're not, i.e. with `include_prompt` set to `False` on the `Pooling` module, then you must use this PR, and then use the regular training setup with one extra parameter in the `SentenceTransformerTrainer`:

```python
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    prompts={
        "query": "Represent this sentence for searching relevant passages: ",
    },
    evaluator=dev_evaluator,
)
```

The `prompts` parameter accepts the same formats as described in the pull request overview below.
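For context, excluding the prompt from mean pooling is toggled on the model itself; a minimal sketch (the base model name here is just an example):

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("microsoft/mpnet-base")  # example base model with mean pooling
# Exclude prompt tokens from mean pooling; during training, the PR's
# <column_name>_prompt_length columns are then used to mask the instruction.
model.set_pooling_include_prompt(include_prompt=False)
```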
I do want to warn you that I'm about to fully overhaul this PR, although the usage will remain the same.
Thanks so much. And if I were interested in training with dynamic prompts (a unique prompt per sample), would that be possible with the methods you described?
Unique per sample is not possible here without subclassing the Trainer, no. You could use a unique prompt per dataset, if that helps. I didn't think that a unique prompt per sample was a notable use case, so I didn't think to integrate it.
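For completeness, when the prompt does not need to be excluded from pooling, one manual workaround (not part of this PR) is to keep a per-row prompt column and merge it into the text with `map`; a small sketch with a made-up dataset:

```python
from datasets import Dataset

# Toy dataset where every row carries its own instruction.
dataset = Dataset.from_dict({
    "prompt": ["Represent this question: ", "Represent this title: "],
    "query": ["when did richmond last win a premiership", "richmond football club"],
})

# Prepend the per-row prompt to the text column, then drop the prompt column.
# No *_prompt_length column is produced, so this only works with include_prompt=True.
dataset = dataset.map(
    lambda row: {"query": row["prompt"] + row["query"]},
    remove_columns=["prompt"],
)
print(dataset[0])  # {'query': 'Represent this question: when did richmond last win a premiership'}
```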
Got it. Thank you!
Heya @ArthurCamara, I've overhauled the prompt prepending once more, as I still had some slight concerns with the previous implementations after some experimentation. You have worked on 2 implementations, and I'm now proposing a third as well, since I had concerns with the first two.

After getting some valuable recommendations from the Datasets team, and @lhoestq in particular, I've settled on this new approach. I've also trained 2 near-identical models, one without prompts and one with. The former consistently performs slightly worse than the model with the prompts.

Also, the prompts model shows the prompts in the model card easily: https://huggingface.co/tomaarsen/mpnet-base-nq-prompts#natural-questions

Lastly, I built an extensive training suite for this feature, because there are a LOT of moving parts between training, evaluation, iterable datasets, and the various prompt formats.

I'm curious about your thoughts on my proposal @ArthurCamara, as I know you're using this yourself too! And one final question: should the prompts be passed to the trainer itself or to the training arguments?
Adding the prompts to the model card is something very useful that I hadn't thought of. Nice.

Nice to learn something new; I didn't know about that one.

Neat. I like the way prompting helps to disentangle the representations of queries and documents, even in smaller models.

Good question. I want to say it should be in the Arguments, so it can be easily swapped out when testing with different configurations. But I'm not sure how good of a UX it would be to pass a double-nested dictionary as an argument to a training script (of course, reading from a JSON/YAML file is also an option).
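As a rough illustration of that last option (the file contents below are made up), the double-nested mapping could live in a JSON file and be loaded in the training script before being passed as the `prompts` argument:

```python
import json

# Example contents of a hypothetical prompts.json: dataset name -> column name -> prompt.
raw = """
{
  "natural_questions": {
    "query": "Represent this sentence for searching relevant passages: "
  }
}
"""
prompts = json.loads(raw)
print(prompts["natural_questions"]["query"])
```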
This is just safer & less hacky: I encountered a nasty bug where only returning 1 value (because we technically only need 1) resulted in all other samples being skipped. Not great.

This also already mentions the v3.3 release; a bit premature, but it's a tad simpler this way.

Thanks a bunch for spearheading this. I didn't expect that the prompts would have such a notable impact (0.66% and 0.90% relative NDCG@10 across mpnet-base and bert-base-uncased, respectively), but I'm glad that they do. This will be included as one of the 4 major features in Monday's v3.3 release, alongside the NanoBEIREvaluator, another major feature. I really appreciate your work on these.
Pull Request overview

- Adds support for prompts (and instructions) to the `Trainer` class.
- Masks the prompts in `Pooling` when training.

Details

Currently, the `encode` method of `SentenceTransformer` supports adding prompts (or instructions) dynamically to the sentences by passing either `prompt` or `prompt_name`. However, this is not supported when training, as mentioned in #2945, since training uses the `forward` method instead.

This PR implements similar functionality in the `Trainer`, by adding a `prompts` parameter that can be:

- `str`: The prompt will be prepended to all sentences in the dataset.
- `dict[str, str]`: If the keys are column names, the prompt is prepended to the respective column. If the training dataset is a dictionary of datasets and the dictionary keys are dataset names, the prompt is prepended to all columns of the respective dataset.
- `dict[str, dict[str, str]]`: Same as above, but assumes the first level is the dataset names and the second level is the column names.

As the prompts can be dynamic (changing for each dataset and column), they are injected into the sentences by the `get_train_dataloader`, `get_eval_dataloader`, and `get_test_dataloader` methods, which call `add_prompts_to_dataset` to resolve, for each dataset and column, which prompt to inject.

Finally, `add_prompts_to_dataset` also adds `<column_name>_prompt_length` columns that, when passed to the `Pooling` module with `include_prompt=False`, mask the instructions properly as well. (Currently this is only set explicitly for `Instructor` models, but it can be enabled by the user by calling `model.set_pooling_include_prompt(include_prompt=False)`.)
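As a quick illustration of the three accepted shapes (the dataset and column names below are just examples):

```python
# 1. A single string, prepended to every column of every dataset.
prompts = "Represent this sentence: "

# 2. A mapping from column name (or dataset name) to a prompt.
prompts = {"query": "Represent this sentence for searching relevant passages: "}

# 3. A nested mapping: dataset name -> column name -> prompt.
prompts = {
    "natural_questions": {
        "query": "Represent this sentence for searching relevant passages: ",
    },
}
```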